diff --git a/.rat-excludes b/.rat-excludes index c0f81b57fe09d..994c7e86f8a91 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -80,5 +80,8 @@ local-1425081759269/* local-1426533911241/* local-1426633911242/* local-1430917381534/* +local-1430917381535_1 +local-1430917381535_2 DESCRIPTION NAMESPACE +test_support/* diff --git a/LICENSE b/LICENSE index 9d1b00beff748..d0cd0dcb4bdb7 100644 --- a/LICENSE +++ b/LICENSE @@ -853,6 +853,52 @@ and Vis.js may be distributed under either license. +======================================================================== +For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +======================================================================== +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +======================================================================== +For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): +======================================================================== +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+ ======================================================================== BSD-style licenses ======================================================================== diff --git a/R/create-docs.sh b/R/create-docs.sh index 4194172a2e115..6a4687b06ecb9 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -23,14 +23,14 @@ # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html +set -o pipefail +set -e + # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" pushd $FWDIR -# Generate Rd file -Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' - -# Install the package +# Install the package (this will also generate the Rd files) ./install-dev.sh # Now create HTML files diff --git a/R/install-dev.sh b/R/install-dev.sh index 55ed6f4be1a4a..1edd551f8d243 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -26,11 +26,20 @@ # NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory # to load the SparkR package on the worker nodes. +set -o pipefail +set -e FWDIR="$(cd `dirname $0`; pwd)" LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -# Install R +pushd $FWDIR + +# Generate Rd files if devtools is installed +Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' + +# Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ + +popd diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 411126a377950..f9447f6c3288d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -19,9 +19,11 @@ exportMethods("arrange", "count", "describe", "distinct", + "dropna", "dtypes", "except", "explain", + "fillna", "filter", "first", "group_by", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ed8093c80d360..0af5cb8881e35 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1314,9 +1314,8 @@ setMethod("except", #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", @@ -1338,9 +1337,8 @@ setMethod("write.df", #' @aliases saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) }) @@ -1431,3 +1429,128 @@ setMethod("describe", sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) dataFrame(sdf) }) + +#' dropna +#' +#' Returns a new DataFrame omitting rows with null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param how "any" or "all". +#' if "any", drop a row if it contains any nulls. +#' if "all", drop a row only if all its values are null. +#' if minNonNulls is specified, how is ignored. +#' @param minNonNulls If specified, drop rows that have less than +#' minNonNulls non-null values. +#' This overwrites the how parameter. +#' @param cols Optional list of column names to consider. 
+#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dropna(df) +#' } +setMethod("dropna", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + how <- match.arg(how) + if (is.null(cols)) { + cols <- columns(x) + } + if (is.null(minNonNulls)) { + minNonNulls <- if (how == "any") { length(cols) } else { 1 } + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- callJMethod(naFunctions, "drop", + as.integer(minNonNulls), listToSeq(as.list(cols))) + dataFrame(sdf) + }) + +#' @aliases dropna +#' @export +setMethod("na.omit", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(x, how, minNonNulls, cols) + }) + +#' fillna +#' +#' Replace null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param value Value to replace null values with. +#' Should be an integer, numeric, character or named list. +#' If the value is a named list, then cols is ignored and +#' value must be a mapping from column name (character) to +#' replacement value. The replacement value must be an +#' integer, numeric or character. +#' @param cols optional list of column names to consider. +#' Columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' fillna(df, 1) +#' fillna(df, list("age" = 20, "name" = "unknown")) +#' } +setMethod("fillna", + signature(x = "DataFrame"), + function(x, value, cols = NULL) { + if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { + stop("value should be an integer, numeric, charactor or named list.") + } + + if (class(value) == "list") { + # Check column names in the named list + colNames <- names(value) + if (length(colNames) == 0 || !all(colNames != "")) { + stop("value should be an a named list with each name being a column name.") + } + + # Convert to the named list to an environment to be passed to JVM + valueMap <- new.env() + for (col in colNames) { + # Check each item in the named list is of valid type + v <- value[[col]] + if (!(class(v) %in% c("integer", "numeric", "character"))) { + stop("Each item in value should be an integer, numeric or charactor.") + } + valueMap[[col]] <- v + } + + # When value is a named list, caller is expected not to pass in cols + if (!is.null(cols)) { + warning("When value is a named list, cols is ignored!") + cols <- NULL + } + + value <- valueMap + } else if (is.integer(value)) { + # Cast an integer to a numeric + value <- as.numeric(value) + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- if (length(cols) == 0) { + callJMethod(naFunctions, "fill", value) + } else { + callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + } + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 36cc612875879..22a4b5bf86ebd 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -452,20 +452,31 @@ dropTempTable <- function(sqlContext, tableName) { #' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -read.df <- function(sqlContext, 
path = NULL, source = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlContext, "load", source, options) + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + schema$jobj, options) + } else { + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + } dataFrame(sdf) } #' @aliases loadDF #' @export -loadDF <- function(sqlContext, path = NULL, source = NULL, ...) { - read.df(sqlContext, path, source, ...) +loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { + read.df(sqlContext, path, source, schema, ...) } #' Create an external table diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a23d3b217b2fd..12e09176c9f92 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -396,6 +396,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname nafunctions +#' @export +setGeneric("dropna", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") + }) + +#' @rdname nafunctions +#' @export +setGeneric("na.omit", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") + }) + #' @rdname schema #' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) @@ -408,6 +422,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @export setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname nafunctions +#' @export +setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -482,11 +500,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") }) +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) #' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) +setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) #' @rdname schema #' @export diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index c53d0a961016f..3169d7968f8fe 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -37,6 +37,14 @@ writeObject <- function(con, object, writeType = TRUE) { # passing in vectors as arrays and instead require arrays to be passed # as lists. 
type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # lists and pairlists + if (type %in% c("integer", "character", "logical", "double", "numeric")) { + if (is.na(object)) { + object <- NULL + type <- "NULL" + } + } if (writeType) { writeType(con, type) } @@ -160,6 +168,14 @@ writeList <- function(con, arr) { } } +# Used to pass arrays where the elements can be of different types +writeGenericList <- function(con, list) { + writeInt(con, length(list)) + for (elem in list) { + writeObject(con, elem) + } +} + # Used to pass in hash maps required on Java side. writeEnv <- function(con, env) { len <- length(env) @@ -168,7 +184,7 @@ writeEnv <- function(con, env) { if (len > 0) { writeList(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeList(con, as.list(vals)) + writeGenericList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 68387f0f5365d..5ced7c688f98a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -225,14 +225,21 @@ sparkR.init <- function( #' sqlContext <- sparkRSQL.init(sc) #'} -sparkRSQL.init <- function(jsc) { +sparkRSQL.init <- function(jsc = NULL) { if (exists(".sparkRSQLsc", envir = .sparkREnv)) { return(get(".sparkRSQLsc", envir = .sparkREnv)) } + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - jsc) + "createSQLContext", + sc) assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) sqlContext } @@ -249,12 +256,19 @@ sparkRSQL.init <- function(jsc) { #' sqlContext <- sparkRHive.init(sc) #'} -sparkRHive.init <- function(jsc) { +sparkRHive.init <- function(jsc = NULL) { if (exists(".sparkRHivesc", envir = .sparkREnv)) { return(get(".sparkRHivesc", envir = .sparkREnv)) } - ssc <- callJMethod(jsc, "sc") + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) }, error = function(err) { diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index ca94f1d4e7fd5..773b6ecf582d9 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -24,7 +24,7 @@ old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = "")) + sc <- SparkR::sparkR.init() assign("sc", sc, envir=.GlobalEnv) sqlContext <- SparkR::sparkRSQL.init(sc) assign("sqlContext", sqlContext, envir=.GlobalEnv) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1857e636e8577..8946348ef801c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -32,6 +32,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") writeLines(mockLines, jsonPath) +# For test nafunctions, like dropna(), fillna(),... 
+mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -92,6 +101,43 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) @@ -495,6 +541,19 @@ test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) + + # Check if we can apply a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_true(inherits(df1, "DataFrame")) + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Run the same with loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_true(inherits(df2, "DataFrame")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) test_that("write.df() as parquet file", { @@ -765,5 +824,105 @@ test_that("describe() on a DataFrame", { expect_equal(collect(stats)[5, "age"], "30") }) +test_that("dropna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. 
+ expect_true(identical(expected$age, actual$age)) + expect_true(identical(expected$height, actual$height)) + expect_true(identical(expected$name, actual$name)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_true(identical(expected, actual)) +}) + +test_that("fillna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_true(identical(expected, actual)) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_true(identical(expected, actual)) +}) + unlink(parquetPath) unlink(jsonPath) +unlink(jsonPathNa) diff --git a/README.md b/README.md index 9c09d40e2bdae..380422ca00dbe 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. 
It also supports a -rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLlib for machine learning, GraphX for graph processing, +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. @@ -22,7 +22,7 @@ This README file only contains basic setup instructions. Spark is built using [Apache Maven](http://maven.apache.org/). To build Spark and its example programs, run: - mvn -DskipTests clean package + build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at @@ -43,7 +43,7 @@ Try the following command, which should return 1000: Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark - + And run the following command, which should also return 1000: >>> sc.parallelize(range(1000)).count() @@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -75,7 +75,7 @@ can be run using: ./dev/run-tests -Please see the guidance on how to +Please see the guidance on how to [run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). 
## A Note About Hadoop Versions diff --git a/assembly/pom.xml b/assembly/pom.xml index 626c8577e31fe..e9c6d26ccddc7 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 1f3dec91314f2..ed5c37e595a96 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index ccb262a4ee02a..fb10d734ac74b 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.bagel -import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { +class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ diff --git a/bin/pyspark b/bin/pyspark index 8acad6113797d..f9dbddfa53560 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,24 +17,10 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/pyspark [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi +export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. @@ -90,11 +76,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 09b4149c2a439..45e9e3def5121 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -21,6 +21,7 @@ rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. call %SPARK_HOME%\bin\load-spark-env.cmd +set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/spark-class b/bin/spark-class index c49d97ce5cf25..2b59e5df5736f 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,18 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -set -e # Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" . "$SPARK_HOME"/bin/load-spark-env.sh -if [ -z "$1" ]; then - echo "Usage: spark-class []" 1>&2 - exit 1 -fi - # Find the java binary if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" @@ -64,24 +58,6 @@ fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" -# Verify that versions of java used to build the jars and run Spark are compatible -if [ -n "$JAVA_HOME" ]; then - JAR_CMD="$JAVA_HOME/bin/jar" -else - JAR_CMD="jar" -fi - -if [ $(command -v "$JAR_CMD") ] ; then - jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1) - if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then - echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 - echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 - echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 - echo "or build Spark with Java 6." 1>&2 - exit 1 - fi -fi - LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" # Add the launcher build dir to the classpath if requested. @@ -98,9 +74,4 @@ CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") - -if [ "${CMD[0]}" = "usage" ]; then - "${CMD[@]}" -else - exec "${CMD[@]}" -fi +exec "${CMD[@]}" diff --git a/bin/spark-shell b/bin/spark-shell index b3761b5e1375b..a6dc863d83fc6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -29,20 +29,7 @@ esac set -o posix export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -usage() { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-shell [options]" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi +export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, # so we need to add the "-Dscala.usejavacp=true" flag manually. We diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 00fd30fa38d36..251309d67f860 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -18,12 +18,7 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. - -echo "%*" | findstr " \<--help\> \<-h\>" >nul -if %ERRORLEVEL% equ 0 ( - call :usage - exit /b 0 -) +set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] rem SPARK-4161: scala does not assume use of the java classpath, rem so we need to add the "-Dscala.usejavacp=true" flag manually. We @@ -37,16 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -echo "Usage: .\bin\spark-shell.cmd [options]" >&2 -call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2 -goto :eof +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql index ca1729f4cfcb4..4ea7bc6e39c07 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,41 +17,6 @@ # limitations under the License. # -# -# Shell script for starting the Spark SQL CLI - -# Enter posix mode for bash -set -o posix - -# NOTE: This exact class name is matched downstream by SparkSubmit. -# Any changes need to be reflected there. 
-export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" - -# Figure out where Spark is installed export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -function usage { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-sql [options] [cli option]" - pattern="usage" - pattern+="\|Spark assembly has been built with Hive" - pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" - pattern+="\|Spark Command: " - pattern+="\|--help" - pattern+="\|=======" - - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - echo - echo "CLI options:" - "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi - -exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" +exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 0e0afe71a0f05..255378b0f077c 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,16 +22,4 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 -# Only define a usage function if an upstream script hasn't done so. -if ! type -t usage >/dev/null 2>&1; then - usage() { - if [ -n "$1" ]; then - echo "$1" - fi - "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help - exit "$2" - } - export -f usage -fi - exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index d3fc4a5cc3f6e..651376e526928 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,15 +24,4 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -call %~dp0spark-class2.cmd %CLASS% %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help -goto :eof +%~dp0spark-class2.cmd %CLASS% %* diff --git a/bin/sparkR b/bin/sparkR index 8c918e2b09aef..464c29f369424 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,23 +17,7 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" - source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/sparkR [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi - +export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7de0011a48ca8..7f17bc7eea4f5 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -4,7 +4,7 @@ # divided into instances which correspond to internal components. # Each instance can be configured to report its metrics to one or more sinks. # Accepted values for [instance] are "master", "worker", "executor", "driver", -# and "applications". A wild card "*" can be used as an instance name, in +# and "applications". 
A wildcard "*" can be used as an instance name, in # which case all instances will inherit the supplied property. # # Within an instance, a "source" specifies a particular set of grouped metrics. @@ -32,7 +32,7 @@ # name (see examples below). # 2. Some sinks involve a polling period. The minimum allowed polling period # is 1 second. -# 3. Wild card properties can be overridden by more specific properties. +# 3. Wildcard properties can be overridden by more specific properties. # For example, master.sink.console.period takes precedence over # *.sink.console.period. # 4. A metrics specific configuration @@ -47,6 +47,13 @@ # instance master and applications. MetricsServlet may not be configured by self. # +## List of available common sources and their properties. + +# org.apache.spark.metrics.source.JvmSource +# Note: Currently, JvmSource is the only available common source +# to add additionaly to an instance, to enable this, +# set the "class" option to its fully qulified class name (see examples below) + ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink diff --git a/core/pom.xml b/core/pom.xml index bfa49d0d6dc25..40a64beccdc24 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -338,6 +338,12 @@ org.seleniumhq.selenium selenium-java + + + com.google.guava + guava + + test @@ -377,9 +383,15 @@ test - org.spark-project + net.razorvine pyrolite 4.4 + + + net.razorvine + serpent + + net.sf.py4j diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java new file mode 100644 index 0000000000000..d3d6280284beb --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. 
This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ * <p>
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ * <ul>
+ *   <li>no Ordering is specified,</li>
+ *   <li>no Aggregator is specified, and</li>
+ *   <li>the number of partitions is less than
+ *     <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ * <p>

+ * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { + + private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final Serializer serializer; + + /** Array of file writers, one for each partition */ + private BlockObjectWriter[] partitionWriters; + + public BypassMergeSortShuffleWriter( + SparkConf conf, + BlockManager blockManager, + Partitioner partitioner, + ShuffleWriteMetrics writeMetrics, + Serializer serializer) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.numPartitions = partitioner.numPartitions(); + this.blockManager = blockManager; + this.partitioner = partitioner; + this.writeMetrics = writeMetrics; + this.serializer = serializer; + } + + @Override + public void insertAll(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new BlockObjectWriter[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. 
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (BlockObjectWriter writer : partitionWriters) { + writer.commitAndClose(); + } + } + + @Override + public long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException { + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public void stop() throws IOException { + if (partitionWriters != null) { + try { + final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); + for (BlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + writer.revertPartialWritesAndClose(); + if (!diskBlockManager.getFile(writer.blockId()).delete()) { + logger.error("Error while deleting file for block {}", writer.blockId()); + } + } + } finally { + partitionWriters = null; + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java new file mode 100644 index 0000000000000..656ea0401a144 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Product2; +import scala.collection.Iterator; + +import org.apache.spark.annotation.Private; +import org.apache.spark.TaskContext; +import org.apache.spark.storage.BlockId; + +/** + * Interface for objects that {@link SortShuffleWriter} uses to write its output files. + */ +@Private +public interface SortShuffleFileWriter { + + void insertAll(Iterator> records) throws IOException; + + /** + * Write all the data added into this shuffle sorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException; + + void stop() throws IOException; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 013db8df9b363..0b450dc76bc38 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -50,4 +50,9 @@ $(function() { $("span.additional-metric-title").click(function() { $(this).parent().find('input[type="checkbox"]').trigger('click'); }); + + // Trigger a double click on the span to show full job description. + $(".description-input").dblclick(function() { + $(this).removeClass("description-input").addClass("description-input-full"); + }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index dbacbf19beee5..dde6069000bc4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -100,7 +100,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortfwdind')); sortrevind = document.createElement('span'); sortrevind.id = "sorttable_sortrevind"; - sortrevind.innerHTML = stIsIE ? ' 5' : ' ▴'; + sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾'; this.appendChild(sortrevind); return; } @@ -113,7 +113,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortrevind')); sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); return; } @@ -134,7 +134,7 @@ sorttable = { this.className += ' sorttable_sorted'; sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); // build an array to sort. 
This is a Schwartzian transform thing, diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index aaeba5b1027c9..7a0dec2a3eaec 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -193,7 +193,7 @@ function renderDagVizForJob(svgContainer) { // Use the link from the stage table so it also works for the history server var attemptId = 0 var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) - .select("a") + .select("a.name-link") .attr("href") + "&expandDagViz=true"; container = svgContainer .append("a") @@ -235,7 +235,7 @@ function renderDagVizForJob(svgContainer) { // them separately later. Note that we cannot draw them now because we need to // put these edges in a separate container that is on top of all stage graphs. metadata.selectAll(".incoming-edge").each(function(v) { - var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4] + var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4] crossStageEdges.push(edge); }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 604c29994145a..ca74ef9d7e94e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -46,7 +46,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var jobPagePath = $(getSelectorForJobEntry(this)).find("a").attr("href") + var jobPagePath = $(getSelectorForJobEntry(this)).find("a.name-link").attr("href") window.location.href = jobPagePath }); @@ -105,7 +105,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var stagePagePath = $(getSelectorForStageEntry(this)).find("a").attr("href") + var stagePagePath = $(getSelectorForStageEntry(this)).find("a.name-link").attr("href") window.location.href = stagePagePath }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index e7c1d475d4e52..b1cef47042247 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -135,6 +135,14 @@ pre { display: block; } +.description-input-full { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: normal; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 330df1d59a9b1..5a8d17bd99933 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -228,7 +228,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @tparam T result type */ class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T,T](initialValue, param, name) { + extends Accumulable[T, T](initialValue, param, name) { def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala 
b/core/src/main/scala/org/apache/spark/Aggregator.scala index af9765d313e9e..ceeb58075d345 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,8 +34,8 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. + // When spilling is enabled sorting will happen externally, but not necessarily with an + // ExternalSorter. private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") @@ -45,7 +45,7 @@ case class Aggregator[K, V, C] ( def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kv: Product2[K, V] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) @@ -76,7 +76,7 @@ case class Aggregator[K, V, C] ( : Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kc: Product2[K, C] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 9514604752640..49329423dca76 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -101,6 +101,9 @@ private[spark] class ExecutorAllocationManager( private val executorIdleTimeoutS = conf.getTimeAsSeconds( "spark.dynamicAllocation.executorIdleTimeout", "60s") + private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s") + // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -150,6 +153,13 @@ private[spark] class ExecutorAllocationManager( // Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem. val executorAllocationManagerSource = new ExecutorAllocationManagerSource + // Whether we are still waiting for the initial set of executors to be allocated. + // While this is true, we will not cancel outstanding executor requests. This is + // set to false when: + // (1) a stage is submitted, or + // (2) an executor idle timeout has elapsed. + @volatile private var initializing: Boolean = true + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. 
@@ -240,6 +250,7 @@ private[spark] class ExecutorAllocationManager( removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime if (expired) { + initializing = false removeExecutor(executorId) } !expired @@ -261,15 +272,23 @@ private[spark] class ExecutorAllocationManager( private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { val maxNeeded = maxNumExecutorsNeeded - if (maxNeeded < numExecutorsTarget) { + if (initializing) { + // Do not change our target while we are still initializing, + // Otherwise the first job may have to ramp up unnecessarily + 0 + } else if (maxNeeded < numExecutorsTarget) { // The target number exceeds the number we actually need, so stop adding new // executors and inform the cluster manager to cancel the extra pending requests val oldNumExecutorsTarget = numExecutorsTarget numExecutorsTarget = math.max(maxNeeded, minNumExecutors) - client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 - logInfo(s"Lowering target number of executors to $numExecutorsTarget because " + - s"not all requests are actually needed (previously $oldNumExecutorsTarget)") + + // If the new target has not changed, avoid sending a message to the cluster manager + if (numExecutorsTarget < oldNumExecutorsTarget) { + client.requestTotalExecutors(numExecutorsTarget) + logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") + } numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) @@ -443,9 +462,23 @@ private[spark] class ExecutorAllocationManager( private def onExecutorIdle(executorId: String): Unit = synchronized { if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + // Note that it is not necessary to query the executors since all the cached + // blocks we are concerned with are reported to the driver. Note that this + // does not include broadcast blocks. + val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val now = clock.getTimeMillis() + val timeout = { + if (hasCachedBlocks) { + // Use a different timeout if the executor has cached blocks. 
+ now + cachedExecutorIdleTimeoutS * 1000 + } else { + now + executorIdleTimeoutS * 1000 + } + } + val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow + removeTimes(executorId) = realTimeout logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000 + s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)") } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") @@ -477,6 +510,7 @@ private[spark] class ExecutorAllocationManager( private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + initializing = false val stageId = stageSubmitted.stageInfo.stageId val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 91f9ef8ce7185..48792a958130c 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished - + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index f2b024ff6cb67..6909015ff66e6 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -29,7 +29,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. It will also + * components to convey liveness or execution information for in-progress tasks. It will also * expire the hosts that have not heartbeated for more than spark.network.timeout. 
*/ private[spark] case class Heartbeat( @@ -43,8 +43,8 @@ private[spark] case class Heartbeat( */ private[spark] case object TaskSchedulerIsSet -private[spark] case object ExpireDeadHosts - +private[spark] case object ExpireDeadHosts + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** @@ -62,18 +62,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = + private val slaveTimeoutMs = sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") - private val executorTimeoutMs = + private val executorTimeoutMs = sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 - + // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" - private val timeoutIntervalMs = + private val timeoutIntervalMs = sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") - private val checkTimeoutIntervalMs = + private val checkTimeoutIntervalMs = sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 - + private var timeoutCheckingTask: ScheduledFuture[_] = null // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not @@ -140,7 +140,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } } - + override def onStop(): Unit = { if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7e706bcc42f04..7cf7bc0dc6810 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -50,8 +50,8 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() - - // If we only stop sc, but the driver process still run as a services then we need to delete + + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs try { Utils.deleteRecursively(baseDir) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index b8d244408bc5b..82889bcd30988 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -103,7 +103,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { */ class RangePartitioner[K : Ordering : ClassTag, V]( @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], + @transient rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -185,7 +185,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => + case r: RangePartitioner[_, _] => r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending case _ => false @@ -249,7 +249,7 @@ private[spark] object RangePartitioner { * @param sampleSizePerPartition max sample size per partition * @return (total number of items, an array of (partitionId, number of items, sample)) */ - def sketch[K:ClassTag]( + def sketch[K : ClassTag]( rdd: RDD[K], sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id @@ -272,7 +272,7 @@ private[spark] 
object RangePartitioner { * @param partitions number of partitions * @return selected bounds */ - def determineBounds[K:Ordering:ClassTag]( + def determineBounds[K : Ordering : ClassTag]( candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = { val ordering = implicitly[Ordering[K]] diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala deleted file mode 100644 index 54fc3a856adfa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SizeEstimator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.apache.spark.annotation.DeveloperApi - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -@DeveloperApi -object SizeEstimator { - /** - * :: DeveloperApi :: - * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate - * includes space taken up by objects referenced by the given object, their references, and so on - * and so forth. - * - * This is useful for determining the amount of heap space a broadcast variable will occupy on - * each executor or the amount of space each object will take when caching objects in - * deserialized form. This is not the same as the serialized size of the object, which will - * typically be much smaller. - */ - @DeveloperApi - def estimate(obj: AnyRef): Long = org.apache.spark.util.SizeEstimator.estimate(obj) -} diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b5e5d6f1465f3..46d72841dccce 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsBytes(key: String, defaultValue: String): Long = { Utils.byteStringAsBytes(get(key, defaultValue)) } - + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. @@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsKb(key: String, defaultValue: String): Long = { Utils.byteStringAsKb(get(key, defaultValue)) } - + /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. 
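The whitespace-only hunks above sit next to the doc comments for SparkConf's size getters, which spell out the contract: an explicit suffix such as "k", "m" or "g" is converted, and a bare number falls back to the unit named in the method. A small usage sketch; the configuration keys here are made up purely for illustration:

    import org.apache.spark.SparkConf

    val conf = new SparkConf().set("spark.example.buffer", "1g")  // hypothetical key
    conf.getSizeAsMb("spark.example.buffer", "0")    // 1024: the "1g" suffix is converted to Mebibytes
    conf.getSizeAsMb("spark.example.unset", "512")   // 512: no suffix on the default, Mebibytes assumed
    conf.getSizeAsKb("spark.example.unset", "512")   // 512: same default read back as Kibibytes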
@@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsMb(key: String, defaultValue: String): Long = { Utils.byteStringAsMb(get(key, defaultValue)) } - + /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. @@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsGb(key: String, defaultValue: String): Long = { Utils.byteStringAsGb(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -480,8 +480,8 @@ private[spark] object SparkConf extends Logging { "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'.") ) - - Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) } /** @@ -508,7 +508,7 @@ private[spark] object SparkConf extends Logging { "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ea6c0dea08e47..a453c9bf4864a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -389,7 +389,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) .toSeq.flatten @@ -438,7 +438,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _ui = if (conf.getBoolean("spark.ui.enabled", true)) { Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, - _env.securityManager,appName, startTime = startTime)) + _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI None @@ -917,7 +917,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], - conf=conf) + conf = conf) val data = br.map { case (k, v) => val bytes = v.getBytes assert(bytes.length == recordLength, "Byte array does not have correct length") @@ -1267,7 +1267,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { - val param = new GrowableAccumulableParam[R,T] + val param = new GrowableAccumulableParam[R, T] val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1316,7 +1316,7 @@ class SparkContext(config: SparkConf) extends 
Logging with ExecutorAllocationCli val uri = new URI(path) val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString - case _ => path + case _ => path } val hadoopPath = new Path(schemeCorrectedPath) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 327114542880d..a185954089528 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -298,7 +298,7 @@ object SparkEnv extends Logging { } } - val mapOutputTracker = if (isDriver) { + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { new MapOutputTrackerWorker(conf) @@ -348,7 +348,7 @@ object SparkEnv extends Logging { val fileServerPort = conf.getInt("spark.fileserver.port", 0) val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) + conf.set("spark.fileserver.uri", server.serverUri) server } else { null diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 2ec42d3aea169..59ac82ccec53b 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -50,8 +50,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private var jID: SerializableWritable[JobID] = null private var taID: SerializableWritable[TaskAttemptID] = null - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null + @transient private var writer: RecordWriter[AnyRef, AnyRef] = null + @transient private var format: OutputFormat[AnyRef, AnyRef] = null @transient private var committer: OutputCommitter = null @transient private var jobContext: JobContext = null @transient private var taskContext: TaskAttemptContext = null @@ -114,10 +114,10 @@ class SparkHadoopWriter(@transient jobConf: JobConf) // ********* Private Functions ********* - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { + private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { if (format == null) { format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] + .asInstanceOf[OutputFormat[AnyRef, AnyRef]] } format } @@ -138,7 +138,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + taskContext = newTaskAttemptContext(conf.value, taID.value) } taskContext } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index fe6320b504e15..a1ebbecf93b7b 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -51,7 +51,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() val files1 = for (name <- classNames) yield { - createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) } val files2 = for ((childName, baseName) <- classNamesWithBase) yield { createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) diff --git 
a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 61af867b11b9c..a650df605b92e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) */ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index db4e996feb31c..ed312770ee131 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] @@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 74db7643224f5..c95615a5a9307 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @deprecated("Use partitions() instead.", "1.1.0") def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - + /** Set of partitions in this RDD. */ def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + /** The partitioner of this RDD. */ + def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. 
*/ def context: SparkContext = rdd.context @@ -96,7 +99,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -492,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } - def takeSample(withReplacement: Boolean, num: Int): JList[T] = + def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 2d92f6a42b308..55a37f8c944b2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -723,7 +723,7 @@ private[spark] object PythonRDD extends Logging { val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] - converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec) } /** @@ -797,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) - /** + /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded actor anyway. - */ + */ @transient var socket: Socket = _ def openSocket(): Socket = synchronized { @@ -843,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * An Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. 
*/ +// scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { /** @@ -884,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } +// scalastyle:on no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 0a91977928cee..d24c650d37bb0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -44,11 +44,11 @@ private[spark] class RBackend { bossGroup = new NioEventLoopGroup(2) val workerGroup = bossGroup val handler = new RBackendHandler(this) - + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(classOf[NioServerSocketChannel]) - + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { def initChannel(ch: SocketChannel): Unit = { ch.pipeline() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 0075d963711f1..2e86984c66b3a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend) val reply = bos.toByteArray ctx.write(reply) } - + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { ctx.flush() } @@ -124,7 +124,7 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args:_*) + val ret = methods.head.invoke(obj, args : _*) // Write status bit writeInt(dos, 0) @@ -135,7 +135,7 @@ private[r] class RBackendHandler(server: RBackend) matchMethod(numArgs, args, x.getParameterTypes) }.head - val obj = ctor.newInstance(args:_*) + val obj = ctor.newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 06247f7e8b78c..4dfa7325934ff 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -309,7 +309,7 @@ private class StringRRDD[T: ClassTag]( } private object SpecialLengths { - val TIMING_DATA = -1 + val TIMING_DATA = -1 } private[r] class BufferedStreamThread( @@ -355,7 +355,6 @@ private[r] object RRDD { val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) - .setJars(jars) // Override `master` if we have a user-specified value if (master != "") { @@ -373,7 +372,11 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) } - new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(sparkConf) + jars.foreach { jar => + jsc.addJar(jar) + } + jsc } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 371dfe454d1a2..f8e3f1a79082e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -157,9 +157,11 @@ private[spark] object SerDe { val keysLen = readInt(in) val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - val valuesType = readObjectType(in) val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) 
+ val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) mapAsJavaMap(keys.zip(values).toMap) } else { new java.util.HashMap[Object, Object]() diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4457c75e8b0fc..b69af639f7862 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -125,7 +125,7 @@ private[broadcast] object HttpBroadcast extends Logging { securityManager = securityMgr if (isDriver) { createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) + conf.set("spark.httpBroadcast.uri", serverUri) } serverUri = conf.get("spark.httpBroadcast.uri") cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) @@ -187,7 +187,7 @@ private[broadcast] object HttpBroadcast extends Logging { } private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name var uc: URLConnection = null diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c048b78910f38..b4edb6109e839 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -65,7 +65,7 @@ private object FaultToleranceTest extends App with Logging { private val workers = ListBuffer[TestWorkerInfo]() private var sc: SparkContext = _ - private val zk = SparkCuratorUtil.newClient(conf) + private val zk = SparkCuratorUtil.newClient(conf) private var numPassed = 0 private var numFailed = 0 diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 198371b70f14f..a0eae774268ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -82,13 +82,13 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(1) + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() + exitFn(1) } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to @@ -99,7 +99,7 @@ object SparkSubmit { /_/ """.format(SPARK_VERSION)) printStream.println("Type --help for more information.") - exitFn() + exitFn(0) } def main(args: Array[String]): Unit = { @@ -160,7 +160,7 @@ object SparkSubmit { // detect exceptions with empty stack traces here, and treat them differently. 
if (e.getStackTrace().length == 0) { printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - exitFn() + exitFn(1) } else { throw e } @@ -361,7 +361,7 @@ object SparkSubmit { pyArchives = pythonPath.mkString(",") } - pyArchives = pyArchives.split(",").map { localPath=> + pyArchives = pyArchives.split(",").map { localPath => val localURI = Utils.resolveURI(localPath) if (localURI.getScheme != "local") { args.files = mergeFileLists(args.files, localURI.toString) @@ -425,9 +425,10 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), - OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), @@ -440,13 +441,11 @@ object SparkSubmit { OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), - - // Yarn client or cluster - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options - OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), @@ -700,7 +699,7 @@ object SparkSubmit { /** * Return whether the given main class represents a sql shell. 
*/ - private def isSqlShell(mainClass: String): Boolean = { + private[deploy] def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } @@ -869,7 +868,7 @@ private[spark] object SparkSubmitUtils { md.addDependency(dd) } } - + /** Add exclusion rules for dependencies already included in the spark-assembly */ def addExclusionRules( ivySettings: IvySettings, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c0e4c771908b3..b7429a901e162 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,12 +17,15 @@ package org.apache.spark.deploy +import java.io.{ByteArrayOutputStream, PrintStream} +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.io.Source import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser @@ -169,6 +172,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull + principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -410,6 +415,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case VERSION => SparkSubmit.printVersionAndExit() + case USAGE_ERROR => + printUsageAndExit(1) + case _ => throw new IllegalArgumentException(s"Unexpected argument '$opt'.") } @@ -447,11 +455,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) } - outStream.println( + val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] - |Usage: spark-submit --status [submission ID] --master [spark://...] - | + |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) + outStream.println(command) + + outStream.println( + """ |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -523,6 +534,65 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | delegation tokens periodically. """.stripMargin ) - SparkSubmit.exitFn() + + if (SparkSubmit.isSqlShell(mainClass)) { + outStream.println("CLI options:") + outStream.println(getSqlShellOptions()) + } + + SparkSubmit.exitFn(exitCode) } + + /** + * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter + * the results to remove unwanted lines. + * + * Since the CLI will call `System.exit()`, we install a security manager to prevent that call + * from working, and restore the original one afterwards. 
+ */ + private def getSqlShellOptions(): String = { + val currentOut = System.out + val currentErr = System.err + val currentSm = System.getSecurityManager() + try { + val out = new ByteArrayOutputStream() + val stream = new PrintStream(out) + System.setOut(stream) + System.setErr(stream) + + val sm = new SecurityManager() { + override def checkExit(status: Int): Unit = { + throw new SecurityException() + } + + override def checkPermission(perm: java.security.Permission): Unit = {} + } + System.setSecurityManager(sm) + + try { + Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + .invoke(null, Array(HELP)) + } catch { + case e: InvocationTargetException => + // Ignore SecurityException, since we throw it above. + if (!e.getCause().isInstanceOf[SecurityException]) { + throw e + } + } + + stream.flush() + + // Get the output and discard any unnecessary lines from it. + Source.fromString(new String(out.toByteArray())).getLines + .filter { line => + !line.startsWith("log4j") && !line.startsWith("usage") + } + .mkString("\n") + } finally { + System.setSecurityManager(currentSm) + System.setOut(currentOut) + System.setErr(currentErr) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 298a8201960d1..5f5e0fe1c34d7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,6 +17,9 @@ package org.apache.spark.deploy.history +import java.util.zip.ZipOutputStream + +import org.apache.spark.SparkException import org.apache.spark.ui.SparkUI private[spark] case class ApplicationAttemptInfo( @@ -62,4 +65,12 @@ private[history] abstract class ApplicationHistoryProvider { */ def getConfig(): Map[String, String] = Map() + /** + * Writes out the event logs to the output stream provided. The logs will be compressed into a + * single zip file and written out. + * @throws SparkException if the logs for the app id cannot be found. 
+ */ + @throws(classOf[SparkException]) + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 45c2be34c8680..5427a88f32ffd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,16 +17,18 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, IOException, InputStream} +import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable +import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.AccessControlException -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ @@ -59,7 +61,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -219,6 +222,58 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + + /** + * This method compresses the files passed in, and writes the compressed data out into the + * [[OutputStream]] passed in. Each file is written as a new [[ZipEntry]] with its name being + * the name of the file being compressed. + */ + def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = { + val fs = FileSystem.get(hadoopConf) + val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer + try { + outputStream.putNextEntry(new ZipEntry(entryName)) + ByteStreams.copy(inputStream, outputStream) + outputStream.closeEntry() + } finally { + inputStream.close() + } + } + + applications.get(appId) match { + case Some(appInfo) => + try { + // If no attempt is specified, or there is no attemptId for attempts, return all attempts + appInfo.attempts.filter { attempt => + attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get + }.foreach { attempt => + val logPath = new Path(logDir, attempt.logPath) + // If this is a legacy directory, then add the directory to the zipStream and add + // each file to that directory. 
+ if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { + val files = fs.listStatus(logPath) + zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) + zipStream.closeEntry() + files.foreach { file => + val path = file.getPath + zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) + } + } else { + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) + } + } + } finally { + zipStream.close() + } + case None => throw new SparkException(s"Logs for $appId not found.") + } + } + + /** * Replay the log files in the list and merge the list of old applications with new ones */ diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 5a0eb585a9049..10638afb74900 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException +import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import com.google.common.cache._ @@ -173,6 +174,13 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + provider.writeEventLogs(appId, attemptId, zipStream) + } + /** * Returns the provider configuration to show in the listing page. * diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index a2a97a7877ce7..4692d22651c93 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. 
*/ -private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 80db6d474b5c1..328d95a7a0c68 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { - + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 756927682cd24..6a7c74020bace 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -75,6 +75,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) + val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", @@ -108,12 +109,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { }.getOrElse { Seq.empty } } -

-            <li><strong>Workers:</strong> {state.workers.size}</li>
-            <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
-              {state.workers.map(_.coresUsed).sum} Used</li>
-            <li><strong>Memory:</strong>
-              {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
-              {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
+            <li><strong>Alive Workers:</strong> {aliveWorkers.size}</li>
+            <li><strong>Cores in use:</strong> {aliveWorkers.map(_.cores).sum} Total,
+              {aliveWorkers.map(_.coresUsed).sum} Used</li>
+            <li><strong>Memory in use:</strong>
+              {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total,
+              {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used</li>
             <li><strong>Applications:</strong>
               {state.activeApps.size} Running,
               {state.completedApps.size} Completed</li>
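The hunk above, together with the aliveWorkers value added earlier in this file's diff, makes the master page summary count only workers in the ALIVE state rather than every registered worker. A self-contained sketch of the same aggregation, with simplified stand-in types (the real code uses org.apache.spark.deploy.master.WorkerInfo and WorkerState):

    // Simplified stand-ins for illustration only.
    object WorkerState extends Enumeration { val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value }
    case class WorkerInfo(state: WorkerState.Value, cores: Int, coresUsed: Int,
                          memory: Int, memoryUsed: Int)

    // Mirror of the page summary: only ALIVE workers contribute to the totals.
    def summarize(workers: Seq[WorkerInfo]): (Int, Int, Int, Int) = {
      val alive = workers.filter(_.state == WorkerState.ALIVE)
      (alive.map(_.cores).sum, alive.map(_.coresUsed).sum,
       alive.map(_.memory).sum, alive.map(_.memoryUsed).sum)
    }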
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index be8560d10fc62..e8ef60bd5428a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

    Driver state information for driver id {driverId}

    - Back to Drivers + Back to Drivers

    Driver state: {driverState.state}

    diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 6078f50518ba4..1fe956320a1b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -57,7 +57,11 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private val supportedMasterPrefixes = Seq("spark://", "mesos://") - private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + private val masters: Array[String] = if (master.startsWith("spark://")) { + Utils.parseStandaloneMasterUrls(master) + } else { + Array(master) + } // Set of masters that lost contact with us, used to keep track of // whether there are masters still alive for us to communicate with diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c8df024dda355..ebc6cd76c6afd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -554,7 +554,7 @@ private[deploy] object Worker extends Logging { conf = conf, securityManager = securityMgr) val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) + masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 88170d4df3053..5a1d06eb87db9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.worker.ui +import java.io.File +import java.net.URI import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -29,6 +31,7 @@ import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir + private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 @@ -129,6 +132,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offsetOption: Option[Long], byteLength: Int ): (String, Long, Long, Long) = { + + if (!supportedLogTypes.contains(logType)) { + return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) + } + + // Verify that the normalized path of the log directory is in the working directory + val normalizedUri = new URI(logDirectory).normalize() + val normalizedLogDir = new File(normalizedUri.getPath) + if (!Utils.isInDirectory(workDir, normalizedLogDir)) { + return ("Error: invalid log directory " + logDirectory, 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") @@ -144,7 +159,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offset } } - val 
endIndex = math.min(startIndex + totalLength, totalLength) + val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") val logText = Utils.offsetBytes(files, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 06152f16ae618..38b61d7242fce 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -43,22 +43,22 @@ class TaskMetrics extends Serializable { private var _hostname: String = _ def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value - + /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value - - + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value - + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ @@ -261,7 +261,7 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. 
@@ -315,7 +315,7 @@ class ShuffleReadMetrics extends Serializable { def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - + /** * Number of local blocks fetched in this shuffle by this task */ @@ -333,7 +333,7 @@ class ShuffleReadMetrics extends Serializable { def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - + /** * Total number of remote bytes read from the shuffle by this task */ @@ -381,7 +381,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @@ -389,7 +389,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** * Total number of records written to the shuffle by this task */ diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index cfd20392d12f1..390d148bc97f9 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -60,7 +60,7 @@ trait SparkHadoopMapReduceUtil { val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if(isMap) "MAP" else "REDUCE") + taskTypeClass, if (isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 8edf493780687..d7495551ad233 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -23,10 +23,10 @@ import java.util.Properties import scala.collection.mutable import scala.util.matching.Regex -import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} -private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { +private[spark] class MetricsConfig(conf: SparkConf) extends Logging { private val DEFAULT_PREFIX = "*" private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r @@ -46,23 +46,14 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi // Add default properties in case there's no properties file setDefaultProperties(properties) - // If spark.metrics.conf is not set, try to get file in class path - val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse { - try { - 
Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)) - } catch { - case e: Exception => - logError("Error loading default configuration file", e) - None - } - } + loadPropertiesFromFile(conf.getOption("spark.metrics.conf")) - isOpt.foreach { is => - try { - properties.load(is) - } finally { - is.close() - } + // Also look for the properties in provided Spark configuration + val prefix = "spark.metrics.conf." + conf.getAll.foreach { + case (k, v) if k.startsWith(prefix) => + properties.setProperty(k.substring(prefix.length()), v) + case _ => } propertyCategories = subProperties(properties, INSTANCE_REGEX) @@ -97,5 +88,31 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) } } -} + /** + * Loads configuration from a config file. If no config file is provided, try to get file + * in class path. + */ + private[this] def loadPropertiesFromFile(path: Option[String]): Unit = { + var is: InputStream = null + try { + is = path match { + case Some(f) => new FileInputStream(f) + case None => Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME) + } + + if (is != null) { + properties.load(is) + } + } catch { + case e: Exception => + val file = path.getOrElse(DEFAULT_METRICS_CONF_FILENAME) + logError(s"Error loading configuration file $file", e) + } finally { + if (is != null) { + is.close() + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 9150ad35712a1..ed5131c79fdc5 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -70,8 +70,7 @@ private[spark] class MetricsSystem private ( securityMgr: SecurityManager) extends Logging { - private[this] val confFile = conf.get("spark.metrics.conf", null) - private[this] val metricsConfig = new MetricsConfig(Option(confFile)) + private[this] val metricsConfig = new MetricsConfig(conf) private val sinks = new mutable.ArrayBuffer[Sink] private val sources = new mutable.ArrayBuffer[Source] diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index e8b3074e8f1a6..11dfcfe2f04e1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -26,9 +26,9 @@ import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem private[spark] class Slf4jSink( - val property: Properties, + val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) + securityMgr: SecurityManager) extends Sink { val SLF4J_DEFAULT_PERIOD = 10 val SLF4J_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala index 90e3aa70b99ef..670e683663324 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala @@ -20,4 +20,4 @@ package org.apache.spark.metrics /** * Sinks used in Spark's metrics system. 
*/ -package object sink +package object sink diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index b573f1a8a5fcb..67a376102994c 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -110,7 +110,7 @@ private[nio] class BlockMessage() { def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data - def getLevel: StorageLevel = level + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() @@ -155,7 +155,7 @@ private[nio] class BlockMessage() { override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 1ba25aa74aa02..7d0806f0c2580 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -114,8 +114,8 @@ private[nio] object BlockMessageArray { val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear + val buffer = ByteBuffer.allocate(100) + buffer.clear() BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, StorageLevel.MEMORY_ONLY_SER)) } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 6b898bd4bfc1b..1499da07bb83b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -326,15 +326,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // MUST be called within the selector loop def connect() { - try{ + try { channel.register(selector, SelectionKey.OP_CONNECT) channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { - case e: Exception => { + case e: Exception => logError("Error connecting to " + address, e) callOnExceptionCallbacks(e) - } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 497871ed6d5e5..c0bca2c4bc994 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -635,12 +635,11 @@ private[nio] class ConnectionManager( val message = securityMsgResp.toBufferMessage if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => { + } catch { + case e: Exception => logError("Error handling sasl client authentication", e) waitingConn.close() throw new IOException("Error evaluating sasl response: ", e) - } } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index 747a2088a7258..232c552f9865d 100644 --- 
a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -75,7 +75,7 @@ private[nio] class SecurityMessage extends Logging { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - connectionId = idBuilder.toString() + connectionId = idBuilder.toString() val tokenLength = buffer.getInt() token = new Array[Byte](tokenLength) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 2ab41ba488ff6..8ae76c5f72f2e 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.4.0-SNAPSHOT" + val SPARK_VERSION = "1.5.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 3ef3cc219dec6..91b07ce3af1b6 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.collection.OpenHashMap * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] { + extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { var outputsMerged = 0 - var sums = new OpenHashMap[T,Long]() // Sum of counts for each key + var sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index bbf1b83af0795..ca1eb1f4e4a9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -85,9 +85,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi numPartsToTry = partsScanned * 4 } else { // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max(1, + numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 0d130dd4c7a60..a4715e3437d94 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -49,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.length + val numPart = partitionFiles.length if (numPart > 0 && (! 
partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 0c1b02c07d09f..663eebb8e4191 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -310,11 +310,11 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: def throwBalls() { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { + for ((p, i) <- prev.partitions.zipWithIndex) { groupArr(i).arr += p } } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { + for (i <- 0 until maxPartitions) { val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2ab967f4bb313..84456d6d868dc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -196,7 +196,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 8653cdee1adee..cfd3e26faf2b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -328,7 +328,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) reduceByKeyLocally(func) } - /** + /** * Count the number of elements for each key, collecting the results to a local Map. 
* * Note that this method should only be used if the resulting map is expected to be small, as @@ -467,7 +467,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 val bufs = combineByKey[CompactBuffer[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -1011,7 +1011,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, @@ -1027,7 +1027,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L Utils.tryWithSafeFinally { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 7598ff617b399..9e3880714a79f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -86,7 +86,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( } val location = if (locations.isEmpty) { None - } else { + } else { // Find the location that maximum number of parent partitions prefer Some(locations.groupBy(x => x).maxBy(_._2.length)._1) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d772f03f76651..10610f4b6f1ff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -434,11 +434,11 @@ abstract class RDD[T: ClassTag]( * @return A random sub-sample of the RDD without replacement. 
*/ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { - this.mapPartitionsWithIndex { case (index, partition) => + this.mapPartitionsWithIndex( { (index, partition) => val sampler = new BernoulliCellSampler[T](lb, ub) sampler.setSeed(seed + index) sampler.sample(partition) - } + }, preservesPartitioning = true) } /** @@ -454,7 +454,7 @@ abstract class RDD[T: ClassTag]( withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { - val numStDev = 10.0 + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") @@ -1138,8 +1138,8 @@ abstract class RDD[T: ClassTag]( if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } - val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T,Long] = { (ctx, iter) => - val map = new OpenHashMap[T,Long] + val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T, Long] = { (ctx, iter) => + val map = new OpenHashMap[T, Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } @@ -1585,15 +1585,15 @@ abstract class RDD[T: ClassTag]( case 0 => Seq.empty case 1 => val d = rdd.dependencies.head - debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]], true) case _ => val frontDeps = rdd.dependencies.take(len - 1) val frontDepStrings = frontDeps.flatMap( - d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]])) + d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]])) val lastDep = rdd.dependencies.last val lastDepStrings = - debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) (frontDepStrings ++ lastDepStrings) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 3dfcf67f0eb66..4b5f15dd06b85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -104,13 +104,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag if (!convertKey && !convertValue) { self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (x._1, anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), x._2)).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 633aeba3bbae6..f7cb1791d4ac6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala 
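A small runnable illustration (not part of this patch) of why the randomSampleWithRange hunk above now passes preservesPartitioning = true to mapPartitionsWithIndex. The object name, the local[2] master, and the toy data are made up for the example, and it assumes a Spark installation of roughly this era on the classpath.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object PreservePartitioningExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sampling-sketch").setMaster("local[2]"))
    try {
      // A hash-partitioned pair RDD, standing in for the RDD being sampled.
      val pairs = sc.parallelize(1 to 1000).map(i => (i, i)).partitionBy(new HashPartitioner(4))

      // The default mapPartitionsWithIndex drops the partitioner, so downstream joins reshuffle.
      val dropped = pairs.mapPartitionsWithIndex((idx, iter) => iter.filter(_._1 % 10 == 0))

      // Per-partition filtering/sampling never moves rows between partitions, so it is safe to
      // keep the partitioner by opting in, as the patched randomSampleWithRange now does.
      val kept = pairs.mapPartitionsWithIndex(
        (idx, iter) => iter.filter(_._1 % 10 == 0), preservesPartitioning = true)

      println(dropped.partitioner.isDefined) // false
      println(kept.partitioner.isDefined)    // true
    } finally {
      sc.stop()
    }
  }
}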
@@ -125,7 +125,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index a96b6c3d23454..81f40ad33aa5d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -123,7 +123,7 @@ private[spark] class ZippedPartitionsRDD3 } private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( + [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5d812918a13d1..75a567fb31520 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -193,9 +193,15 @@ class DAGScheduler( def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms => - bms.map(bm => TaskLocation(bm.host, bm.executorId)) + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + Seq.fill(rdd.partitions.size)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } } cacheLocs(rdd.id) = locs } @@ -208,19 +214,17 @@ class DAGScheduler( /** * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) */ private def getShuffleMapStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, jobId) + registerShuffleDependencies(shuffleDep, firstJobId) // Then register current shuffleDep - val stage = newOrUsedShuffleStage(shuffleDep, jobId) + val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage @@ -230,15 +234,15 @@ class DAGScheduler( /** * Helper function to eliminate some code re-use when creating new stages. 
*/ - private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { - val parentStages = getParentStages(rdd, jobId) + private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, firstJobId) val id = nextStageId.getAndIncrement() (parentStages, id) } /** * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in - * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * newOrUsedShuffleStage. The stage will be associated with the provided firstJobId. * Production of shuffle map stages should always use newOrUsedShuffleStage, not * newShuffleMapStage directly. */ @@ -246,21 +250,19 @@ class DAGScheduler( rdd: RDD[_], numTasks: Int, shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, + firstJobId: Int, callSite: CallSite): ShuffleMapStage = { - val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId) val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, - jobId, callSite, shuffleDep) + firstJobId, callSite, shuffleDep) stageIdToStage(id) = stage - updateJobIdStageIdMaps(jobId, stage) + updateJobIdStageIdMaps(firstJobId, stage) stage } /** - * Create a ResultStage -- either directly for use as a result stage, or as part of the - * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated - * with the provided jobId. + * Create a ResultStage associated with the provided jobId. */ private def newResultStage( rdd: RDD[_], @@ -277,16 +279,16 @@ class DAGScheduler( /** * Create a shuffle map Stage for the given RDD. The stage will also be associated with the - * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd val numTasks = rdd.partitions.size - val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) @@ -304,10 +306,10 @@ class DAGScheduler( } /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. + * Get or create the list of parent stages for a given RDD. The new Stages will be created with + * the provided firstJobId. 
*/ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError @@ -321,7 +323,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - parents += getShuffleMapStage(shufDep, jobId) + parents += getShuffleMapStage(shufDep, firstJobId) case _ => waitingForVisit.push(dep.rdd) } @@ -336,11 +338,11 @@ class DAGScheduler( } /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, jobId) + val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } @@ -386,11 +388,12 @@ class DAGScheduler( def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { missing += mapStage } @@ -577,7 +580,7 @@ class DAGScheduler( private[scheduler] def doCancelAllJobs() { // Cancel all running jobs. - runningStages.map(_.jobId).foreach(handleJobCancellation(_, + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, reason = "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... 
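A minimal standalone sketch (not Spark source) of the getCacheLocs short-circuit added above: when an RDD's storage level is NONE nothing can be cached, so the block-manager round trip can be skipped entirely. Rdd, StorageLevel, TaskLocation, and lookupBlockLocations below are simplified stand-ins for the real Spark types.

object CacheLocsSketch {
  sealed trait StorageLevel
  object StorageLevel {
    case object NONE extends StorageLevel
    case object MEMORY_ONLY extends StorageLevel
  }

  final case class TaskLocation(host: String, executorId: String)
  final case class Rdd(id: Int, numPartitions: Int, storageLevel: StorageLevel)

  // Stand-in for BlockManagerMaster.getLocations: the (pretend) remote lookup that the patch
  // avoids for RDDs that were never persisted.
  private def lookupBlockLocations(rdd: Rdd): Seq[Seq[TaskLocation]] = {
    println(s"block-manager lookup for RDD ${rdd.id}")
    Seq.fill(rdd.numPartitions)(Seq(TaskLocation("host-1", "exec-1")))
  }

  private val cacheLocs = scala.collection.mutable.HashMap.empty[Int, Seq[Seq[TaskLocation]]]

  def getCacheLocs(rdd: Rdd): Seq[Seq[TaskLocation]] = cacheLocs.synchronized {
    // Mirrors the patched logic: check/insert explicitly (this path runs O(num tasks) times)
    // and return all-empty locations without any lookup when the storage level is NONE.
    if (!cacheLocs.contains(rdd.id)) {
      val locs: Seq[Seq[TaskLocation]] =
        if (rdd.storageLevel == StorageLevel.NONE) {
          Seq.fill(rdd.numPartitions)(Nil)
        } else {
          lookupBlockLocations(rdd)
        }
      cacheLocs(rdd.id) = locs
    }
    cacheLocs(rdd.id)
  }

  def main(args: Array[String]): Unit = {
    println(getCacheLocs(Rdd(1, 3, StorageLevel.NONE)))        // no block-manager lookup printed
    println(getCacheLocs(Rdd(2, 2, StorageLevel.MEMORY_ONLY))) // one lookup printed
  }
}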
@@ -603,7 +606,7 @@ class DAGScheduler( clearCacheLocs() val failedStagesCopy = failedStages.toArray failedStages.clear() - for (stage <- failedStagesCopy.sortBy(_.jobId)) { + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -623,7 +626,7 @@ class DAGScheduler( logTrace("failed: " + failedStages) val waitingStagesCopy = waitingStages.toArray waitingStages.clear() - for (stage <- waitingStagesCopy.sortBy(_.jobId)) { + for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -843,7 +846,7 @@ class DAGScheduler( } } - val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -909,7 +912,7 @@ class DAGScheduler( stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1323,7 +1326,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back @@ -1364,10 +1367,10 @@ class DAGScheduler( private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 - if (!visited.add((rdd,partition))) { + if (!visited.add((rdd, partition))) { // Nil has already been returned for previously visited partitions. return Nil } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 12668b6c0988e..02c67073af6a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 86f357abb8723..c6d957b65f3fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. 
* @param sourceName Filename (or other source identifier) from whence @logData is being read - * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not */ def replay( @@ -62,7 +62,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { if (!maybeTruncated || lines.hasNext) { throw jpe } else { - logWarning(s"Got JsonParseException from log file $sourceName" + + logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index c0f3d5a13d623..bf81b9aca4810 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -28,9 +28,9 @@ private[spark] class ResultStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { // The active job for this result stage. Will be empty if the job has already finished // (e.g., because the job was cancelled). diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 5e62c8468f007..864941d468af9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -56,7 +56,7 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var compare:Int = 0 + var compare: Int = 0 if (s1Needy && !s2Needy) { return true diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index d02210743484c..66c75f325fcde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -30,10 +30,10 @@ private[spark] class ShuffleMapStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite, val shuffleDep: ShuffleDependency[_, _, _]) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { override def toString: String = "ShuffleMapStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 863d0befbc19e..9620915f495ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -270,7 +270,7 @@ class StatsReportListener extends SparkListener with Logging { private[spark] object StatsReportListener extends Logging { // For profiling, the extremes are more interesting - val percentiles = 
Array[Int](0,5,10,25,50,75,90,95,100) + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) val probabilities = percentiles.map(_ / 100.0) val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" @@ -304,7 +304,7 @@ private[spark] object StatsReportListener extends Logging { dOpt.foreach { d => showDistribution(heading, d, formatNumber)} } - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -318,7 +318,7 @@ private[spark] object StatsReportListener extends Logging { } def showBytesDistribution( - heading:String, + heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long], taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5d0ddb8377c33..c59d6e4f5bc04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.CallSite * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes * that each output partition is on. * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO + * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * @@ -51,7 +51,7 @@ private[spark] abstract class Stage( val rdd: RDD[_], val numTasks: Int, val parents: List[Stage], - val jobId: Int, + val firstJobId: Int, val callSite: CallSite) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 586d1e06204c1..15101c64f0503 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -125,7 +125,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index b4b8a630694bb..ed3dde0fc3055 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{TimerTask, Timer} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -64,6 
+64,9 @@ private[spark] class TaskSchedulerImpl( // How often to check for speculative tasks val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") + private val speculationScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation") + // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") @@ -142,10 +145,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, - SPECULATION_INTERVAL_MS milliseconds) { - Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - }(sc.env.actorSystem.dispatcher) + speculationScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + checkSpeculatableTasks() + } + }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS) } } @@ -412,6 +416,7 @@ private[spark] class TaskSchedulerImpl( } override def stop() { + speculationScheduler.shutdown() if (backend != null) { backend.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c4487d5b37247..82455b0426a5d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @param sched the TaskSchedulerImpl associated with the TaskSetManager * @param taskSet the TaskSet to manage scheduling for - * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * @param maxTaskFailures if any particular task fails this number of times, the entire * task set will be aborted */ private[spark] class TaskSetManager( @@ -781,10 +781,10 @@ private[spark] class TaskSetManager( // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because dequeueTaskFromList will skip already-running tasks. 
for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, @@ -861,9 +861,9 @@ private[spark] class TaskSetManager( case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" case _ => null } - + if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait) } else { 0L } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 70364cea62a80..4be1eda2e9291 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -75,7 +75,8 @@ private[spark] object CoarseGrainedClusterMessages { case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) + case class AddWebUIFilter( + filterName: String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage // Messages exchanged between the driver and the cluster manager for executor allocation diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c5bc6294a5577..7c7f70d8a193b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -84,7 +84,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") - + reviveThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReviveOffers)) @@ -103,7 +103,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case None => // Ignoring the update since we don't know about the executor. 
logWarning(s"Ignored task status update ($taskId state $state) " + - "from unknown executor $sender with ID $executorId") + s"from unknown executor with ID $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 2a3a5d925d06f..190ff61d689d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -149,7 +149,7 @@ private[spark] abstract class YarnSchedulerBackend( } } - override def onStop(): Unit ={ + override def onStop(): Unit = { askAmThreadPool.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index dc59545b43314..6b8edca5aa485 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -25,9 +25,10 @@ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** @@ -51,7 +52,7 @@ private[spark] class CoarseMesosSchedulerBackend( val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] @@ -115,11 +116,9 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(sc.env.actorSystem), + val driverUrl = sc.env.rpcEnv.uriOf( SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index db0a080b3b0c0..49de85ef48ada 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -146,7 +146,7 @@ private[spark] class MesosSchedulerBackend( private def createExecArg(): Array[Byte] = { if (execArgs == null) { val props = new HashMap[String, String] - for ((key,value) <- sc.conf.getAll) { + for ((key, value) <- sc.conf.getAll) { props(key) = value } // Serialize the map as an array of (String, String) pairs diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 928c5cfed417a..e79c543a9de27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -37,14 +37,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .newBuilder() .setMode(Volume.Mode.RW) spec match { - case Array(container_path) => + case Array(container_path) => Some(vol.setContainerPath(container_path)) case Array(container_path, "rw") => Some(vol.setContainerPath(container_path)) case Array(container_path, "ro") => Some(vol.setContainerPath(container_path) .setMode(Volume.Mode.RO)) - case Array(host_path, container_path) => + case Array(host_path, container_path) => Some(vol.setContainerPath(container_path) .setHostPath(host_path)) case Array(host_path, container_path, "rw") => @@ -108,7 +108,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { image: String, volumes: Option[List[Volume]] = None, network: Option[ContainerInfo.DockerInfo.Network] = None, - portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None):Unit = { + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = { val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 217957963437d..cd8a82347a1e9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,8 +17,9 @@ package org.apache.spark.serializer -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer +import javax.annotation.Nullable import scala.reflect.ClassTag @@ -51,7 +52,7 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") - + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") @@ -136,21 +137,45 @@ class KryoSerializer(conf: SparkConf) } private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) +class KryoSerializationStream( + serInstance: KryoSerializerInstance, + outStream: OutputStream) extends SerializationStream { + + private[this] var output: KryoOutput = new KryoOutput(outStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - override def flush() { output.flush() } - override def close() { output.close() } + override def flush() { + if (output == null) { + throw new IOException("Stream is closed") + } + output.flush() + } + + override def close() { + if (output != null) { + try { + output.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + output = null + } + } + } } private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - private val input = new 
KryoInput(inStream) +class KryoDeserializationStream( + serInstance: KryoSerializerInstance, + inStream: InputStream) extends DeserializationStream { + + private[this] var input: KryoInput = new KryoInput(inStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def readObject[T: ClassTag](): T = { try { @@ -163,52 +188,105 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } override def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() + if (input != null) { + try { + // Kryo's Input automatically closes the input stream it is using. + input.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + input = null + } + } } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - private val kryo = ks.newKryo() - // Make these lazy vals to avoid creating a buffer unless we use them + /** + * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do + * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching + * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are + * not synchronized. + */ + @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + + /** + * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; + * otherwise, it allocates a new instance. + */ + private[serializer] def borrowKryo(): Kryo = { + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have been modified + // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } + } + + /** + * Release a borrowed [[Kryo]] instance. If this serializer instance already has a cached Kryo + * instance, then the given Kryo instance is discarded; otherwise, the Kryo is stored for later + * re-use. + */ + private[serializer] def releaseKryo(kryo: Kryo): Unit = { + if (cachedKryo == null) { + cachedKryo = kryo + } + } + + // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() private lazy val input = new KryoInput() override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() - kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766) + val kryo = borrowKryo() try { kryo.writeClassAndObject(output, t) } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. 
To avoid this, " + "increase spark.kryoserializer.buffer.max value.") + } finally { + releaseKryo(kryo) } ByteBuffer.wrap(output.toBytes) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] + val kryo = borrowKryo() + try { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + releaseKryo(kryo) + } } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + val kryo = borrowKryo() val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj + try { + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + kryo.setClassLoader(oldClassLoader) + releaseKryo(kryo) + } } override def serializeStream(s: OutputStream): SerializationStream = { - kryo.reset() // We must reset in case this serializer instance was reused (see SPARK-7766) - new KryoSerializationStream(kryo, s) + new KryoSerializationStream(this, s) } override def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) + new KryoDeserializationStream(this, s) } /** @@ -218,7 +296,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ def getAutoReset(): Boolean = { val field = classOf[Kryo].getDeclaredField("autoReset") field.setAccessible(true) - field.get(kryo).asInstanceOf[Boolean] + val kryo = borrowKryo() + try { + field.get(kryo).asInstanceOf[Boolean] + } finally { + releaseKryo(kryo) + } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 6078c9d433ebf..bd2704dc81871 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag @@ -114,8 +115,12 @@ object Serializer { /** * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. + * + * It is legal to create multiple serialization / deserialization streams from the same + * SerializerInstance as long as those streams are all used within the same thread. 
*/ @DeveloperApi +@NotThreadSafe abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer @@ -177,6 +182,7 @@ abstract class DeserializationStream { } catch { case eof: EOFException => finished = true + null } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 80374adc44296..597d46a3d2223 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -80,7 +80,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blocksByAddress, serializer, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index c9dd6bfc4c219..5865e7640c1cf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: ExternalSorter[K, V, _] = null + private var sorter: SortShuffleFileWriter[K, V] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { - if (dep.mapSideCombine) { + sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - sorter = new ExternalSorter[K, V, C]( + new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.insertAll(records) + } else if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. 
+ new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, + writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) - sorter.insertAll(records) + new ExternalSorter[K, V, V]( + aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } + sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C]( } } +private[spark] object SortShuffleWriter { + def shouldBypassMergeSort( + conf: SparkConf, + numPartitions: Int, + aggregator: Option[Aggregator[_, _, _]], + keyOrdering: Option[Ordering[_]]): Boolean = { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 50608588f09ae..390c136df79b3 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -169,7 +169,7 @@ private[v1] object AllStagesResource { val outputMetrics: Option[OutputMetricDistributions] = new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw:InternalTaskMetrics): Option[InternalOutputMetrics] = { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = { raw.outputMetrics } def build: OutputMetricDistributions = new OutputMetricDistributions( @@ -284,7 +284,7 @@ private[v1] object AllStagesResource { * the options (returning None if the metrics are all empty), and extract the quantiles for each * metric. After creating an instance, call metricOption to get the result type. 
*/ -private[v1] abstract class MetricHelper[I,O]( +private[v1] abstract class MetricHelper[I, O]( rawMetrics: Seq[InternalTaskMetrics], quantiles: Array[Double]) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index bf2cc2e72f1fe..50b6ba67e9931 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.zip.ZipOutputStream import javax.servlet.ServletContext import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -101,7 +102,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications/{appId}/stages") - def getStages(@PathParam("appId") appId: String): AllStagesResource= { + def getStages(@PathParam("appId") appId: String): AllStagesResource = { uiRoot.withSparkUI(appId, None) { ui => new AllStagesResource(ui) } @@ -110,14 +111,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications/{appId}/{attemptId}/stages") def getStages( @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllStagesResource= { + @PathParam("attemptId") attemptId: String): AllStagesResource = { uiRoot.withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") - def getStage(@PathParam("appId") appId: String): OneStageResource= { + def getStage(@PathParam("appId") appId: String): OneStageResource = { uiRoot.withSparkUI(appId, None) { ui => new OneStageResource(ui) } @@ -164,6 +165,18 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { } } + @Path("applications/{appId}/logs") + def getEventLogs( + @PathParam("appId") appId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, None) + } + + @Path("applications/{appId}/{attemptId}/logs") + def getEventLogs( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } private[spark] object ApiRootResource { @@ -171,7 +184,7 @@ private[spark] object ApiRootResource { def getServletHandler(uiRoot: UIRoot): ServletContextHandler = { val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) jerseyContext.setContextPath("/api") - val holder:ServletHolder = new ServletHolder(classOf[ServletContainer]) + val holder: ServletHolder = new ServletHolder(classOf[ServletContainer]) holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass", "com.sun.jersey.api.core.PackagesResourceConfig") holder.setInitParameter("com.sun.jersey.config.property.packages", @@ -193,6 +206,17 @@ private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + /** + * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is + * [[None]], event logs for all attempts of this application will be written out. 
+ */ + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { + Response.serverError() + .entity("Event logs are only available through the history server.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + /** * Get the spark UI with the given appID, and apply a function * to it. If there is no such app, throw an appropriate exception diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala new file mode 100644 index 0000000000000..22e21f0c62a29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.io.OutputStream +import java.util.zip.ZipOutputStream +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.{MediaType, Response, StreamingOutput} + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil + +@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) +private[v1] class EventLogDownloadResource( + val uIRoot: UIRoot, + val appId: String, + val attemptId: Option[String]) extends Logging { + val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) + + @GET + def getEventLogs(): Response = { + try { + val fileName = { + attemptId match { + case Some(id) => s"eventLogs-$appId-$id.zip" + case None => s"eventLogs-$appId.zip" + } + } + + val stream = new StreamingOutput { + override def write(output: OutputStream): Unit = { + val zipStream = new ZipOutputStream(output) + try { + uIRoot.writeEventLogs(appId, attemptId, zipStream) + } finally { + zipStream.close() + } + + } + } + + Response.ok(stream) + .header("Content-Disposition", s"attachment; filename=$fileName") + .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) + .build() + } catch { + case NonFatal(e) => + Response.serverError() + .entity(s"Event logs are not available for app: $appId.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index 07b224fac4786..dfdc09c6caf3b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.SparkUI private[v1] class OneRDDResource(ui: SparkUI) { @GET - def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { 
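The new EventLogDownloadResource streams a zip of the event logs back to the client by wrapping the HTTP output stream in a ZipOutputStream inside a JAX-RS StreamingOutput. A condensed sketch of that pattern; the zippedResponse helper and the writeLogs callback are hypothetical stand-ins for the resource method and uIRoot.writeEventLogs shown above:

// Stream a zip archive over HTTP without buffering it in memory: the ZipOutputStream writes
// directly into the response's OutputStream, and closing it flushes the zip trailer.
import java.io.OutputStream
import java.util.zip.ZipOutputStream
import javax.ws.rs.core.{MediaType, Response, StreamingOutput}

object ZipStreamingSketch {
  def zippedResponse(fileName: String)(writeLogs: ZipOutputStream => Unit): Response = {
    val stream = new StreamingOutput {
      override def write(output: OutputStream): Unit = {
        val zipStream = new ZipOutputStream(output)
        try {
          writeLogs(zipStream)   // e.g. one zip entry per event log file
        } finally {
          zipStream.close()      // writes the zip central directory
        }
      }
    }
    Response.ok(stream)
      .header("Content-Disposition", s"attachment; filename=$fileName")
      .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM)
      .build()
  }
}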
AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( throw new NotFoundException(s"no rdd found w/ id $rddId") ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index fd24aea63a8a1..f9812f06cf527 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -83,7 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { withStageAttempt(stageId, stageAttemptId) { stage => val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) + tasks.slice(offset, offset + length) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index cee29786c3019..0c71cd2382225 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -16,40 +16,33 @@ */ package org.apache.spark.status.api.v1 -import java.text.SimpleDateFormat +import java.text.{ParseException, SimpleDateFormat} import java.util.TimeZone import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status -import scala.util.Try - private[v1] class SimpleDateParam(val originalValue: String) { - val timestamp: Long = { - SimpleDateParam.formats.collectFirst { - case fmt if Try(fmt.parse(originalValue)).isSuccess => - fmt.parse(originalValue).getTime() - }.getOrElse( - throw new WebApplicationException( - Response - .status(Status.BAD_REQUEST) - .entity("Couldn't parse date: " + originalValue) - .build() - ) - ) - } -} -private[v1] object SimpleDateParam { - - val formats: Seq[SimpleDateFormat] = { - - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") - gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) - - Seq( - new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz"), - gmtDay - ) + val timestamp: Long = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + try { + format.parse(originalValue).getTime() + } catch { + case _: ParseException => + val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) + try { + gmtDay.parse(originalValue).getTime() + } catch { + case _: ParseException => + throw new WebApplicationException( + Response + .status(Status.BAD_REQUEST) + .entity("Couldn't parse date: " + originalValue) + .build() + ) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index ef3c8570d8186..2bec64f2ef02b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -134,7 +134,7 @@ class StageData private[spark]( val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], - val executorSummary:Option[Map[String,ExecutorStageSummary]]) + val executorSummary: Option[Map[String, ExecutorStageSummary]]) class TaskData private[spark]( val taskId: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index abcad9438bf28..7cdae22b0e253 100644 --- 
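The SimpleDateParam rewrite above replaces the Try/collectFirst scan over a shared list of formats with an explicit two-step parse: try the full timestamp format first, then fall back to a GMT day-only format, and report a 400 only if both fail. A minimal standalone version of the same fallback logic, without the WebApplicationException wrapper:

import java.text.{ParseException, SimpleDateFormat}
import java.util.TimeZone

object DateParamSketch {
  def parseTimestamp(value: String): Long = {
    val full = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz")
    try {
      full.parse(value).getTime
    } catch {
      case _: ParseException =>
        val gmtDay = new SimpleDateFormat("yyyy-MM-dd")
        gmtDay.setTimeZone(TimeZone.getTimeZone("GMT"))
        gmtDay.parse(value).getTime   // still throws ParseException for unparseable input
    }
  }

  def main(args: Array[String]): Unit = {
    println(parseTimestamp("2015-06-01"))                   // midnight GMT of that day
    println(parseTimestamp("2015-06-01T12:30:00.000GMT"))   // full timestamp form
  }
}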
a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -202,6 +202,14 @@ class BlockManagerMaster( Await.result(future, timeout) } + /** + * Find out if the executor has cached blocks. This method does not consider broadcast blocks, + * since they are not reported the master. + */ + def hasCachedBlocks(executorId: String): Boolean = { + driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId)) + } + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { if (driverEndpoint != null && isDriver) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 3afb4c3c02e2d..68ed9096731c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} +import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} @@ -112,6 +113,17 @@ class BlockManagerMasterEndpoint( case BlockManagerHeartbeat(blockManagerId) => context.reply(heartbeatReceived(blockManagerId)) + case HasCachedBlocks(executorId) => + blockManagerIdByExecutor.get(executorId) match { + case Some(bm) => + if (blockManagerInfo.contains(bm)) { + val bmInfo = blockManagerInfo(bm) + context.reply(bmInfo.cachedBlocks.nonEmpty) + } else { + context.reply(false) + } + case None => context.reply(false) + } } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -292,16 +304,16 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " + logError("Got two different block manager registrations on same executor - " + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) + removeExecutor(id.executorId) case None => } logInfo("Registering block manager %s with %s RAM, %s".format( id.hostPort, Utils.bytesToString(maxMemSize), id)) - + blockManagerIdByExecutor(id.executorId) = id - + blockManagerInfo(id) = new BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } @@ -418,6 +430,9 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] + // Cached blocks held by this BlockManager. This does not include broadcast blocks. + private val _cachedBlocks = new mutable.HashSet[BlockId] + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { @@ -451,27 +466,35 @@ private[spark] class BlockManagerInfo( * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. 
*/ + var blockStatus: BlockStatus = null if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) + blockStatus = BlockStatus(storageLevel, memSize, 0, 0) + _blocks.put(blockId, blockStatus) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) + blockStatus = BlockStatus(storageLevel, 0, diskSize, 0) + _blocks.put(blockId, blockStatus) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } if (storageLevel.useOffHeap) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, externalBlockStoreSize)) + blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize) + _blocks.put(blockId, blockStatus) logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize))) } + if (!blockId.isBroadcast && blockStatus.isCached) { + _cachedBlocks += blockId + } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) + _cachedBlocks -= blockId if (blockStatus.storageLevel.useMemory) { logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), @@ -494,6 +517,7 @@ private[spark] class BlockManagerInfo( _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) } + _cachedBlocks -= blockId } def remainingMem: Long = _remainingMem @@ -502,6 +526,9 @@ private[spark] class BlockManagerInfo( def blocks: JHashMap[BlockId, BlockStatus] = _blocks + // This does not include broadcast blocks. + def cachedBlocks: collection.Set[BlockId] = _cachedBlocks + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem def clear() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1683576067fe8..376e9eb48843d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,7 +42,6 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. 
////////////////////////////////////////////////////////////////////////////////// @@ -108,4 +107,6 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerMaster case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + + case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 543df4e1350dd..7478ab0fc2f7a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -40,7 +40,7 @@ class BlockManagerSlaveEndpoint( private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 8569c6f3cbbc3..c5ba9af3e2658 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a33f22ef52687..7eeabd1e0489c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter( private var objOut: SerializationStream = null private var initialized = false private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false /** * Cursors used to represent positions in the file. @@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() bs.flush() close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() } - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + commitAndCloseHasBeenCalled = true } // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. 
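The master-side handling of the new HasCachedBlocks message reduces to a map lookup plus a non-emptiness check on the cached-block set that BlockManagerInfo now maintains (broadcast blocks excluded). A self-contained restatement of that lookup, with plain maps standing in for blockManagerIdByExecutor and BlockManagerInfo.cachedBlocks:

import scala.collection.mutable

object HasCachedBlocksSketch {
  // Simplified stand-ins for the master's internal state.
  val blockManagerIdByExecutor = mutable.HashMap[String, String]()              // executorId -> bmId
  val cachedBlocksByBlockManager = mutable.HashMap[String, mutable.HashSet[String]]()

  def hasCachedBlocks(executorId: String): Boolean = {
    blockManagerIdByExecutor.get(executorId) match {
      case Some(bmId) => cachedBlocksByBlockManager.get(bmId).exists(_.nonEmpty)
      case None => false
    }
  }

  def main(args: Array[String]): Unit = {
    blockManagerIdByExecutor("exec-1") = "bm-1"
    cachedBlocksByBlockManager("bm-1") = mutable.HashSet("rdd_2_0")
    println(hasCachedBlocks("exec-1"))   // true: holds a cached RDD block
    println(hasCachedBlocks("exec-2"))   // false: unknown executor
  }
}

A caller on the driver could use BlockManagerMaster.hasCachedBlocks(executorId) to avoid discarding executors that still hold cached data; that use case is an assumption here, not shown in this diff.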
override def revertPartialWritesAndClose() { try { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) - if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter( } override def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } new FileSegment(file, initialPosition, finalPosition - initialPosition) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2a4447705fa65..91ef86389a0c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -139,8 +139,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook { () => - logDebug("Shutdown hook called") + Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } } @@ -151,7 +151,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon try { Utils.removeShutdownHook(shutdownHook) } catch { - case e: Exception => + case e: Exception => logError(s"Exception while removing shutdown hook.", e) } doStop() diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 95e2d688d9b17..021a9facfb0b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,6 +24,8 @@ import java.io.File * based off an offset and a length. 
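The commitAndCloseHasBeenCalled flag added above turns an implicit ordering assumption into an explicit contract: fileSegment() is only meaningful once commitAndClose() has flushed the streams and recorded the final file length. A toy sketch of just that state check (not the real DiskBlockObjectWriter, whose construction needs a block manager):

import java.io.File

// Minimal model of the new contract: committing records the final position; asking for the
// segment before committing fails fast instead of returning a bogus (offset, length) pair.
class ToyBlockWriter(file: File, initialPosition: Long) {
  private var committed = false
  private var finalPosition: Long = -1L

  def commitAndClose(): Unit = {
    // The real writer flushes the serializer and buffered streams, closes them, and only
    // then reads file.length(), since some codecs write extra bytes on close().
    finalPosition = file.length()
    committed = true
  }

  def fileSegment(): (Long, Long) = {
    if (!committed) {
      throw new IllegalStateException(
        "fileSegment() is only valid after commitAndClose() has been called")
    }
    (initialPosition, finalPosition - initialPosition)
  }
}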
*/ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") override def toString: String = { "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index fb4ba0eac9d9a..b53c86e89a273 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -100,7 +100,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { os.write(bytes.array()) } catch { - case NonFatal(e) => + case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) os.cancel() } finally { @@ -114,7 +114,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { blockManager.dataSerializeStream(blockId, os, values) } catch { - case NonFatal(e) => + case NonFatal(e) => logWarning(s"Failed to put values of block $blockId into Tachyon", e) os.cancel() } finally { diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 0b11e914bb251..3788916cf39bb 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -137,7 +137,7 @@ private[spark] object SparkUI { jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, - startTime: Long): SparkUI = { + startTime: Long): SparkUI = { create(Some(sc), conf, listenerBus, securityManager, appName, jobProgressListener = Some(jobProgressListener), startTime = startTime) } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 6194c50ec8c7c..65162f4fdcd62 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -309,7 +309,7 @@ private[spark] object UIUtils extends Logging { started: Int, completed: Int, failed: Int, - skipped:Int, + skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100) diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 5fbcd6bb8ad94..ba03acdb38cc5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -54,7 +54,7 @@ private[spark] object UIWorkloadGenerator { val sc = new SparkContext(conf) def setProperties(s: String): Unit = { - if(schedulingMode == SchedulingMode.FAIR) { + if (schedulingMode == SchedulingMode.FAIR) { sc.setLocalProperty("spark.scheduler.pool", s) } sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 594df15e9cc85..2c84e4485996e 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -62,12 +62,12 @@ private[spark] abstract class WebUI( tab.pages.foreach(attachPage) 
tabs += tab } - + def detachTab(tab: WebUITab) { tab.pages.foreach(detachPage) tabs -= tab } - + def detachPage(page: WebUIPage) { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index e010ebef3b34a..2ce670ad02e97 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -231,7 +231,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { {lastStageDescription} - {lastStageName} + {lastStageName} {formattedSubmissionTime} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 246e191d64776..0c854f04890b6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,8 +17,12 @@ package org.apache.spark.ui.jobs +import java.util.concurrent.TimeoutException + import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import com.google.common.annotations.VisibleForTesting + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -119,7 +123,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { "failedStages" -> failedStages.size ) } - + // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { @@ -278,7 +282,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { ) { jobData.numActiveStages -= 1 if (stage.failureReason.isEmpty) { - jobData.completedStageIndices.add(stage.stageId) + if (!stage.submissionTime.isEmpty) { + jobData.completedStageIndices.add(stage.stageId) + } } else { jobData.numFailedStages += 1 } @@ -311,6 +317,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { jobData <- jobIdToData.get(jobId) ) { jobData.numActiveStages += 1 + + // If a stage retries again, it should be removed from completedStageIndices set + jobData.completedStageIndices.remove(stage.stageId) } } @@ -526,4 +535,30 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onApplicationStart(appStarted: SparkListenerApplicationStart) { startTime = appStarted.time } + + /** + * For testing only. Wait until at least `numExecutors` executors are up, or throw + * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. + * Exposed for testing. + * + * @param numExecutors the number of executors to wait at least + * @param timeout time to wait in milliseconds + */ + private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = { + val finishTime = System.currentTimeMillis() + timeout + while (System.currentTimeMillis() < finishTime) { + val numBlockManagers = synchronized { + blockManagerIds.size + } + if (numBlockManagers >= numExecutors + 1) { + // Need to count the block manager in driver + return + } + // Sleep rather than using wait/notify, because this is used only for testing and wait/notify + // add overhead in the general case. 
+ Thread.sleep(10) + } + throw new TimeoutException( + s"Can't find $numExecutors executors before $timeout milliseconds elapsed") + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 31e2e7fba9783..b83a49f79c8a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -527,7 +527,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { minLaunchTime = launchTime.min(minLaunchTime) maxFinishTime = finishTime.max(maxFinishTime) - def toProportion(time: Long) = (time.toDouble / totalExecutionTime * 100).toLong + def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 val metricsOpt = taskUIData.taskMetrics val shuffleReadTime = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 82ba561eefb16..99812db4912a3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -93,7 +93,7 @@ private[ui] class StageTableBase( } val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" - val nameLink = {s.name} + val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 3d96113aa5fe9..f008d40180611 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -22,6 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.collection.OpenHashSet +import scala.collection.mutable import scala.collection.mutable.HashMap private[spark] object UIData { @@ -63,7 +64,7 @@ private[spark] object UIData { /* Stages */ var numActiveStages: Int = 0, // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: - var completedStageIndices: OpenHashSet[Int] = new OpenHashSet[Int](), + var completedStageIndices: mutable.HashSet[Int] = new mutable.HashSet[Int](), var numSkippedStages: Int = 0, var numFailedStages: Int = 0 ) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index fbce917a0824d..36943978ff594 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -33,7 +33,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val rddId = parameterId.toInt - val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener,includeDetails = true) + val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) .getOrElse { // Rather than crashing, render an "RDD Not Found" page return UIUtils.headerSparkPage("RDD Not Found", Seq[Node](), parent) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index de3316d083a22..96aa2fe164703 100644 --- 
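waitUntilExecutorsUp above (and the reworked waitUntilEmpty later in this patch) both follow the same test-only idiom: poll a condition every 10 ms until a deadline, then throw TimeoutException rather than returning a Boolean the caller might ignore. A generic sketch of that idiom; registeredBlockManagers in the trailing comment is a hypothetical placeholder:

import java.util.concurrent.TimeoutException

object PollUntil {
  def waitUntil(timeoutMillis: Long, description: String)(condition: => Boolean): Unit = {
    val finishTime = System.currentTimeMillis() + timeoutMillis
    while (!condition) {
      if (System.currentTimeMillis() > finishTime) {
        throw new TimeoutException(s"$description not satisfied after $timeoutMillis ms")
      }
      // Plain sleep rather than wait/notify: this is test-only code, so polling overhead is
      // acceptable and the production path stays free of extra synchronization.
      Thread.sleep(10)
    }
  }
}

// e.g. PollUntil.waitUntil(10000, "2 executors registered") { registeredBlockManagers >= 3 }
// (>= 3 because the driver's own block manager is counted too, as in the diff above)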
a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -63,7 +63,7 @@ private[spark] object AkkaUtils extends Logging { conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = { - val akkaThreads = conf.getInt("spark.akka.threads", 4) + val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout", conf.get("spark.network.timeout", "120s")) @@ -235,7 +235,7 @@ private[spark] object AkkaUtils extends Logging { protocol: String, systemName: String, host: String, - port: Any, + port: Int, actorName: String): String = { s"$protocol://$systemName@$host:$port/user/$actorName" } diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index ce7887b76ff96..61b5a4cecddce 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -40,7 +40,7 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri self => private var sparkContext: SparkContext = null - + /* Cap the capacity of the event queue so we get an explicit error (rather than * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ private val EVENT_QUEUE_CAPACITY = 10000 @@ -120,21 +120,22 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri /** * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Return true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. + * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue + * emptied. */ @VisibleForTesting - def waitUntilEmpty(timeoutMillis: Int): Boolean = { + @throws(classOf[TimeoutException]) + def waitUntilEmpty(timeoutMillis: Long): Unit = { val finishTime = System.currentTimeMillis + timeoutMillis while (!queueIsEmpty) { if (System.currentTimeMillis > finishTime) { - return false + throw new TimeoutException( + s"The event queue is not empty after $timeoutMillis milliseconds") } /* Sleep rather than using wait/notify, because this is used only for testing and * wait/notify add overhead in the general case. 
*/ Thread.sleep(10) } - true } /** diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 9044aaeef2d48..31d230d0fec8e 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -42,7 +42,7 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat private[spark] object CompletionIterator { def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A, I] = { - new CompletionIterator[A,I](sub) { + new CompletionIterator[A, I](sub) { def completion(): Unit = completionFunction } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 9aea8efa38c7a..1bab707235b89 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -35,7 +35,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va java.util.Arrays.sort(data, startIdx, endIdx) val length = endIdx - startIdx - val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0) + val defaultProbabilities = Array(0, 0.25, 0.5, 0.75, 1.0) /** * Get the value of the distribution at the given probabilities. Probabilities should be @@ -44,7 +44,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va */ def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) : IndexedSeq[Double] = { - probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))} + probabilities.toIndexedSeq.map { p: Double => data(closestIndex(p)) } } private def closestIndex(p: Double) = { diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2bbfc988a99a8..a8bbad086849e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -89,7 +89,7 @@ private[spark] object MetadataCleaner { conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) { - conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) + conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } /** diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index dad888548ed10..3d95b7869f494 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -45,5 +45,5 @@ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef override def toString: String = "(" + _1 + "," + _2 + ")" - override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] + override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_, _]] } diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 1e0ba5c28754a..169489df6c1ea 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -52,8 +52,8 @@ private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoa * Used to implement 
fine-grained class loading locks similar to what is done by Java 7. This * prevents deadlock issues when using non-hierarchical class loaders. * - * Note that due to Java 6 compatibility (and some issues with implementing class loaders in - * Scala), Java 7's `ClassLoader.registerAsParallelCapable` method is not called. + * Note that due to some issues with implementing class loaders in + * Scala, Java 7's `ClassLoader.registerAsParallelCapable` method is not called. */ private val locks = new ConcurrentHashMap[String, Object]() diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 968a72d5adae9..0180399c9dad5 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -21,31 +21,47 @@ import java.lang.management.ManagementFactory import java.lang.reflect.{Field, Modifier} import java.util.{IdentityHashMap, Random} import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable.ArrayBuffer import scala.runtime.ScalaRunTime import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.OpenHashSet /** + * :: DeveloperApi :: * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in * memory-aware caches. * * Based on the following JavaWorld article: * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html */ -private[spark] object SizeEstimator extends Logging { +@DeveloperApi +object SizeEstimator extends Logging { + + /** + * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate + * includes space taken up by objects referenced by the given object, their references, and so on + * and so forth. + * + * This is useful for determining the amount of heap space a broadcast variable will occupy on + * each executor or the amount of space each object will take when caching objects in + * deserialized form. This is not the same as the serialized size of the object, which will + * typically be much smaller. + */ + def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) // Sizes of primitive types - private val BYTE_SIZE = 1 + private val BYTE_SIZE = 1 private val BOOLEAN_SIZE = 1 - private val CHAR_SIZE = 2 - private val SHORT_SIZE = 2 - private val INT_SIZE = 4 - private val LONG_SIZE = 8 - private val FLOAT_SIZE = 4 - private val DOUBLE_SIZE = 8 + private val CHAR_SIZE = 2 + private val SHORT_SIZE = 2 + private val INT_SIZE = 4 + private val LONG_SIZE = 8 + private val FLOAT_SIZE = 4 + private val DOUBLE_SIZE = 8 // Fields can be primitive types, sizes are: 1, 2, 4, 8. Or fields can be pointers. The size of // a pointer is 4 or 8 depending on the JVM (32-bit or 64-bit) and UseCompressedOops flag. 
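With SizeEstimator promoted to a public @DeveloperApi object above, estimate() can be called directly, for instance to gauge how much heap a value would occupy when broadcast or cached in deserialized form (per the new scaladoc, this is a heap-size estimate, not the serialized size). A minimal usage example, with spark-core on the classpath:

import org.apache.spark.util.SizeEstimator

object SizeEstimateExample {
  def main(args: Array[String]): Unit = {
    // A candidate broadcast value: a lookup table with 100k entries.
    val lookupTable: Map[Int, String] = (1 to 100000).map(i => i -> ("value-" + i)).toMap
    val bytes = SizeEstimator.estimate(lookupTable)
    println(s"Estimated heap footprint: ${bytes / (1024 * 1024)} MB")
  }
}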
@@ -80,7 +96,7 @@ private[spark] object SizeEstimator extends Logging { isCompressedOops = getIsCompressedOops objectSize = if (!is64bit) 8 else { - if(!isCompressedOops) { + if (!isCompressedOops) { 16 } else { 12 @@ -161,8 +177,6 @@ private[spark] object SizeEstimator extends Logging { val shellSize: Long, val pointerFields: List[Field]) {} - def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) - private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = { val state = new SearchState(visited) state.enqueue(obj) @@ -222,7 +236,7 @@ private[spark] object SizeEstimator extends Logging { val s1 = sampleArray(array, state, rand, drawn, length) val s2 = sampleArray(array, state, rand, drawn, length) val size = math.min(s1, s2) - state.size += math.max(s1, s2) + + state.size += math.max(s1, s2) + (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong } } @@ -230,7 +244,7 @@ private[spark] object SizeEstimator extends Logging { private def sampleArray( array: AnyRef, - state: SearchState, + state: SearchState, rand: Random, drawn: OpenHashSet[Int], length: Int): Long = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b7a2473dfe920..153ece6224a6d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -73,6 +73,13 @@ private[spark] object Utils extends Logging { */ val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. + */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null @@ -189,10 +196,11 @@ private[spark] object Utils extends Logging { private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook { () => - logDebug("Shutdown hook called") + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") shutdownDeletePaths.foreach { dirPath => try { + logInfo("Deleting directory " + dirPath) Utils.deleteRecursively(new File(dirPath)) } catch { case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) @@ -882,7 +890,7 @@ private[spark] object Utils extends Logging { // If not, we should change it to LRUCache or something. private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() - def parseHostPort(hostPort: String): (String, Int) = { + def parseHostPort(hostPort: String): (String, Int) = { // Check cache first. val cached = hostPortParseResults.get(hostPort) if (cached != null) { @@ -1287,8 +1295,7 @@ private[spark] object Utils extends Logging { } catch { case t: Throwable => if (originalThrowable != null) { - // We could do originalThrowable.addSuppressed(t), but it's - // not available in JDK 1.6. + originalThrowable.addSuppressed(t) logWarning(s"Suppressing exception in finally: " + t.getMessage, t) throw originalThrowable } else { @@ -2219,6 +2226,22 @@ private[spark] object Utils extends Logging { } } + /** + * Return whether the specified file is a parent directory of the child file. 
+ */ + def isInDirectory(parent: File, child: File): Boolean = { + if (child == null || parent == null) { + return false + } + if (!child.exists() || !parent.exists() || !parent.isDirectory()) { + return false + } + if (parent.equals(child)) { + return true + } + isInDirectory(parent, child.getParentFile) + } + } private [util] class SparkShutdownHookManager { diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 41cb8cfe2afa3..9c15b1188d91c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -161,7 +161,7 @@ class BitSet(numBits: Int) extends Serializable { override def hasNext: Boolean = ind >= 0 override def next(): Int = { val tmp = ind - ind = nextSetBit(ind + 1) + ind = nextSetBit(ind + 1) tmp } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index a60bffe611f14..516aaa44d03fc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -28,11 +28,13 @@ import scala.collection.mutable.ArrayBuffer * occupy a contiguous segment of memory. */ private[spark] class ChainedBuffer(chunkSize: Int) { - private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt - assert(math.pow(2, chunkSizeLog2).toInt == chunkSize, + + private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros( + java.lang.Long.highestOneBit(chunkSize)) + assert((1 << chunkSizeLog2) == chunkSize, s"ChainedBuffer chunk size $chunkSize must be a power of two") private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() - private var _size: Int = _ + private var _size: Long = 0 /** * Feed bytes from this buffer into a BlockObjectWriter. @@ -41,16 +43,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param os OutputStream to read into. * @param len Number of bytes to read. */ - def read(pos: Int, os: OutputStream, len: Int): Unit = { + def read(pos: Long, os: OutputStream, len: Int): Unit = { if (pos + len > _size) { throw new IndexOutOfBoundsException( s"Read of $len bytes at position $pos would go past size ${_size} of buffer") } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toRead = math.min(len - written, chunkSize - posInChunk) + val toRead: Int = math.min(len - written, chunkSize - posInChunk) os.write(chunks(chunkIndex), posInChunk, toRead) written += toRead chunkIndex += 1 @@ -66,16 +68,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param offs Offset in the byte array to read to. * @param len Number of bytes to read. 
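Utils.isInDirectory above walks up the child's parent chain until it either reaches the given directory or runs out of parents, and it deliberately requires both paths to exist. Since Utils is internal, here is a standalone restatement with a small usage example; the event-log-serving use in the comment is only one plausible application, not something this diff states:

import java.io.File

object PathContainment {
  // Same logic as the diff: false for nulls or non-existent paths, true once the walk up
  // the parent chain reaches the (existing) parent directory.
  @annotation.tailrec
  def isInDirectory(parent: File, child: File): Boolean = {
    if (child == null || parent == null) {
      false
    } else if (!child.exists() || !parent.exists() || !parent.isDirectory()) {
      false
    } else if (parent.equals(child)) {
      true
    } else {
      isInDirectory(parent, child.getParentFile)
    }
  }

  def main(args: Array[String]): Unit = {
    val root = new File("/tmp")
    println(isInDirectory(root, new File("/tmp")))        // true: same directory
    println(isInDirectory(root, new File("/etc/hosts")))  // false: outside the root
  }
}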
*/ - def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { if (pos + len > _size) { throw new IndexOutOfBoundsException( s"Read of $len bytes at position $pos would go past size of buffer") } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toRead = math.min(len - written, chunkSize - posInChunk) + val toRead: Int = math.min(len - written, chunkSize - posInChunk) System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) written += toRead chunkIndex += 1 @@ -91,22 +93,22 @@ private[spark] class ChainedBuffer(chunkSize: Int) { * @param offs Offset in the byte array to write from. * @param len Number of bytes to write. */ - def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { if (pos > _size) { throw new IndexOutOfBoundsException( s"Write at position $pos starts after end of buffer ${_size}") } // Grow if needed - val endChunkIndex = (pos + len - 1) >> chunkSizeLog2 + val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt while (endChunkIndex >= chunks.length) { chunks += new Array[Byte](chunkSize) } - var chunkIndex = pos >> chunkSizeLog2 - var posInChunk = pos - (chunkIndex << chunkSizeLog2) - var written = 0 + var chunkIndex: Int = (pos >> chunkSizeLog2).toInt + var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt + var written: Int = 0 while (written < len) { - val toWrite = math.min(len - written, chunkSize - posInChunk) + val toWrite: Int = math.min(len - written, chunkSize - posInChunk) System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) written += toWrite chunkIndex += 1 @@ -119,19 +121,19 @@ private[spark] class ChainedBuffer(chunkSize: Int) { /** * Total size of buffer that can be written to without allocating additional memory. */ - def capacity: Int = chunks.size * chunkSize + def capacity: Long = chunks.size.toLong * chunkSize /** * Size of the logical buffer. */ - def size: Int = _size + def size: Long = _size } /** * Output stream that writes to a ChainedBuffer. 
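The ChainedBuffer changes above widen positions from Int to Long so the logical buffer can grow past 2 GB, while chunk lookup stays cheap because the chunk size is a power of two: shift the Long position for the chunk index, subtract for the offset within that chunk. A small sketch of exactly that arithmetic:

object ChunkMathSketch {
  def locate(pos: Long, chunkSize: Int): (Int, Int) = {
    // numberOfTrailingZeros(highestOneBit(n)) == log2(n) when n is a power of two
    val chunkSizeLog2 =
      java.lang.Long.numberOfTrailingZeros(java.lang.Long.highestOneBit(chunkSize))
    require((1 << chunkSizeLog2) == chunkSize, s"chunk size $chunkSize must be a power of two")
    val chunkIndex = (pos >> chunkSizeLog2).toInt
    val posInChunk = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
    (chunkIndex, posInChunk)
  }

  def main(args: Array[String]): Unit = {
    // A position past 2 GB with 4 MB chunks lands in chunk 768 at offset 5.
    println(locate(pos = 3L * 1024 * 1024 * 1024 + 5, chunkSize = 4 * 1024 * 1024))  // (768,5)
  }
}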
*/ private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { - private var pos = 0 + private var pos: Long = 0 override def write(b: Int): Unit = { throw new UnsupportedOperationException() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index df2d6ad3b41a4..1e4531ef395ae 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,9 +89,9 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - private val fileBufferSize = + private val fileBufferSize = sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3b9d14f9372b6..757dec66c203b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -23,12 +23,14 @@ import java.util.Comparator import scala.collection.mutable.ArrayBuffer import scala.collection.mutable +import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.storage.{BlockObjectWriter, BlockId} +import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} +import org.apache.spark.storage.{BlockId, BlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -84,35 +86,40 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * each other for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. - * - * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is - * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to - * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can - * then concatenate these files to produce a single sorted file, without having to serialize and - * de-serialize each item twice (as is needed during the merge). This speeds up the map side of - * groupBy, sort, etc operations since they do no partial aggregation. 
*/ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) - extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { + extends Logging + with Spillable[WritablePartitionedPairCollection[K, C]] + with SortShuffleFileWriter[K, V] { + + private val conf = SparkEnv.get.conf private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 + private def getPartition(key: K): Int = { + if (shouldPartition) partitioner.get.getPartition(key) else 0 + } + + // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class. + // As a sanity check, make sure that we're not handling a shuffle which should use that path. + if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) { + throw new IllegalArgumentException("ExternalSorter should not be used to handle " + + " a sort that the BypassMergeSortShuffleWriter should handle") + } private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. // @@ -123,43 +130,28 @@ private[spark] class ExternalSorter[K, V, C]( // grow internal data structures by growing + copying every time the number of objects doubles. private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) - private def getPartition(key: K): Int = { - if (shouldPartition) partitioner.get.getPartition(key) else 0 - } - - private val metaInitialRecords = 256 - private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB private val useSerializedPairBuffer = - !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.supportsRelocationOfSerializedObjects - + ordering.isEmpty && + conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && + ser.supportsRelocationOfSerializedObjects + private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB + private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = { + if (useSerializedPairBuffer) { + new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance) + } else { + new PartitionedPairBuffer[K, C] + } + } // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. 
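The newBuffer() factory above picks the serialized pair buffer only when no ordering is defined, spark.shuffle.sort.serializeMapOutputs is enabled, and the serializer supports relocation of serialized objects; otherwise it falls back to the plain pair buffer. A sketch of that selection as a standalone decision function (the Kryo remark in the comment reflects default Kryo behavior and is an assumption, not stated in this diff):

object BufferChoiceSketch {
  sealed trait BufferKind
  case object SerializedPairBuffer extends BufferKind    // PartitionedSerializedPairBuffer
  case object DeserializedPairBuffer extends BufferKind  // PartitionedPairBuffer

  def chooseBuffer(
      hasOrdering: Boolean,
      serializeMapOutputsEnabled: Boolean,       // spark.shuffle.sort.serializeMapOutputs
      serializerSupportsRelocation: Boolean): BufferKind = {
    // Serialized storage is cheaper to spill but only safe when records never need to be
    // re-ordered by key and the serializer can move serialized bytes around as-is.
    if (!hasOrdering && serializeMapOutputsEnabled && serializerSupportsRelocation) {
      SerializedPairBuffer
    } else {
      DeserializedPairBuffer
    }
  }

  def main(args: Array[String]): Unit = {
    // Kryo with default settings typically supports relocation, so a no-ordering shuffle
    // can keep records serialized in memory:
    println(chooseBuffer(hasOrdering = false, serializeMapOutputsEnabled = true,
      serializerSupportsRelocation = true))   // SerializedPairBuffer
  }
}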
private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } + private var buffer = newBuffer() // Total spilling statistics private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled - // Write metrics for current spill - private var curWriteMetrics: ShuffleWriteMetrics = _ - - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need - // local aggregation and sorting, write numPartitions files directly and just concatenate them - // at the end. This avoids doing serialization and deserialization twice to merge together the - // spilled files, which would happen with the normal code path. The downside is having multiple - // files open at a time and thus more memory allocated to buffers. - private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - private val bypassMergeSort = - (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty) - - // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled - private var partitionWriters: Array[BlockObjectWriter] = null // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -174,6 +166,14 @@ private[spark] class ExternalSorter[K, V, C]( } }) + private def comparator: Option[Comparator[K]] = { + if (ordering.isDefined || aggregator.isDefined) { + Some(keyComparator) + } else { + None + } + } + // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. @@ -182,9 +182,10 @@ private[spark] class ExternalSorter[K, V, C]( blockId: BlockId, serializerBatchSizes: Array[Long], elementsPerPartition: Array[Long]) + private val spills = new ArrayBuffer[SpilledFile] - def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -202,15 +203,6 @@ private[spark] class ExternalSorter[K, V, C]( map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) } - } else if (bypassMergeSort) { - // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies - if (records.hasNext) { - spillToPartitionFiles( - WritablePartitionedIterator.fromIterator(records.map { kv => - ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - }) - ) - } } else { // Stick values into our buffer while (records.hasNext) { @@ -238,46 +230,33 @@ private[spark] class ExternalSorter[K, V, C]( } } else { if (maybeSpill(buffer, buffer.estimateSize())) { - buffer = if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } + buffer = newBuffer() } } } /** - * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. 
- */ - override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { - if (bypassMergeSort) { - spillToPartitionFiles(collection) - } else { - spillToMergeableFile(collection) - } - } - - /** - * Spill our in-memory collection to a sorted file that we can merge later (normal code path). - * We add this file into spilledFiles to find it later. - * - * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles() - * is used to write files for each partition. + * Spill our in-memory collection to a sorted file that we can merge later. + * We add this file into `spilledFiles` to find it later. * * @param collection whichever collection we're using (map or buffer) */ - private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = { - assert(!bypassMergeSort) - + override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. val (blockId, file) = diskBlockManager.createTempShuffleBlock() - curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter( - blockId, file, serInstance, fileBufferSize, curWriteMetrics) - var objectsWritten = 0 // Objects written since the last flush + + // These variables are reset after each flush + var objectsWritten: Long = 0 + var spillMetrics: ShuffleWriteMetrics = null + var writer: BlockObjectWriter = null + def openWriter(): Unit = { + assert (writer == null && spillMetrics == null) + spillMetrics = new ShuffleWriteMetrics + writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics) + } + openWriter() // List of batch sizes (bytes) in the order they are written to disk val batchSizes = new ArrayBuffer[Long] @@ -291,8 +270,9 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) + _diskBytesSpilled += spillMetrics.shuffleBytesWritten + batchSizes.append(spillMetrics.shuffleBytesWritten) + spillMetrics = null objectsWritten = 0 } @@ -307,9 +287,7 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter( - blockId, file, serInstance, fileBufferSize, curWriteMetrics) + openWriter() } } if (objectsWritten > 0) { @@ -336,46 +314,6 @@ private[spark] class ExternalSorter[K, V, C]( spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) } - /** - * Spill our in-memory collection to separate files, one for each partition. This is used when - * there's no aggregator and ordering and the number of partitions is small, because it allows - * writePartitionedFile to just concatenate files without deserializing data. 
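The rewritten spill above batches records to disk: a writer is opened with fresh ShuffleWriteMetrics, flushed once serializerBatchSize objects have been written, and each flushed batch's byte count is appended to batchSizes so the spill file can later be read back one batch at a time. Here is a minimal, self-contained sketch of that batching pattern; it uses plain java.io streams in place of BlockObjectWriter and the metrics objects, and all names are placeholders rather than Spark APIs.

import java.io.{DataOutputStream, File, FileOutputStream}
import scala.collection.mutable.ArrayBuffer

// Write (key, value) records, closing out a batch every `batchSize` records and
// remembering each batch's length in bytes.
def writeBatched(records: Iterator[(Int, String)], file: File, batchSize: Int): Array[Long] = {
  val batchSizes = new ArrayBuffer[Long]
  val data = new DataOutputStream(new FileOutputStream(file))
  var objectsWritten = 0
  var lastFlushPos = 0L

  def flush(): Unit = {
    data.flush()
    val pos = data.size().toLong      // total bytes written so far
    batchSizes += pos - lastFlushPos  // length of the batch that just ended
    lastFlushPos = pos
    objectsWritten = 0
  }

  try {
    records.foreach { case (k, v) =>
      data.writeInt(k)
      data.writeUTF(v)
      objectsWritten += 1
      if (objectsWritten == batchSize) flush()
    }
    if (objectsWritten > 0) flush()
  } finally {
    data.close()
  }
  batchSizes.toArray
}

As in the patch, the per-batch lengths are what make it possible to periodically reset the serialization stream without losing track of where each batch starts.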
- * - * @param collection whichever collection we're using (map or buffer) - */ - private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = { - spillToPartitionFiles(collection.writablePartitionedIterator()) - } - - private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = { - assert(bypassMergeSort) - - // Create our file writers if we haven't done so yet - if (partitionWriters == null) { - curWriteMetrics = new ShuffleWriteMetrics() - val openStartTime = System.nanoTime - partitionWriters = Array.fill(numPartitions) { - // Because these files may be read during shuffle, their compression must be controlled by - // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use - // createTempShuffleBlock here; see SPARK-3426 for more context. - val (blockId, file) = diskBlockManager.createTempShuffleBlock() - val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, - curWriteMetrics) - writer.open() - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. - curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) - } - - // No need to sort stuff, just write each element out - while (iterator.hasNext) { - val partitionId = iterator.nextPartition() - iterator.writeNext(partitionWriters(partitionId)) - } - } - /** * Merge a sequence of sorted files, giving an iterator over partitions and then over elements * inside each partition. This can be used to either write out a new file or return data to @@ -665,8 +603,6 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Exposed for testing purposes. - * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -676,10 +612,11 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. 
*/ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + @VisibleForTesting + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer - if (spills.isEmpty && partitionWriters == null) { + if (spills.isEmpty) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { @@ -689,13 +626,6 @@ private[spark] class ExternalSorter[K, V, C]( // We do need to sort by both partition ID and key groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator))) } - } else if (bypassMergeSort) { - // Read data from each partition file and merge it together with the data in memory; - // note that there's no ordering or aggregator in this case -- we just partition objects - val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None)) - collIter.map { case (partitionId, values) => - (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) - } } else { // Merge spilled and in-memory data merge(spills, collection.partitionedDestructiveSortedIterator(comparator)) @@ -709,14 +639,13 @@ private[spark] class ExternalSorter[K, V, C]( /** * Write all the data added into this ExternalSorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. + * called by the SortShuffleWriter. * * @param blockId block ID to write to. The index file will be blockId.name + ".index". * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - def writePartitionedFile( + override def writePartitionedFile( blockId: BlockId, context: TaskContext, outputFile: File): Array[Long] = { @@ -724,28 +653,7 @@ private[spark] class ExternalSorter[K, V, C]( // Track location of each range in the output file val lengths = new Array[Long](numPartitions) - if (bypassMergeSort && partitionWriters != null) { - // We decided to write separate files for each partition, so just concatenate them. To keep - // this simple we spill out the current in-memory collection so that everything is in files. 
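The in-memory-only branch of partitionedIterator above sorts the collection by partition ID (and by key only when an ordering or aggregator requires it) and then hands the result to groupByPartition, which slices one sorted stream into a sub-iterator per partition. Below is a self-contained sketch of that grouping step, assuming the input iterator is already sorted by partition ID; this groupByPartition is a simplified stand-in for illustration, not the private method itself.

// Turn a partition-sorted stream of ((partition, key), value) pairs into
// (partition, iterator) pairs. Each sub-iterator reads from the shared buffered
// stream only while the head element still belongs to its partition, so the
// partitions must be consumed in order -- you cannot skip ahead.
def groupByPartition[K, V](
    numPartitions: Int,
    data: Iterator[((Int, K), V)]): Iterator[(Int, Iterator[(K, V)])] = {
  val buffered = data.buffered
  (0 until numPartitions).iterator.map { p =>
    val partitionIterator = new Iterator[(K, V)] {
      override def hasNext: Boolean = buffered.hasNext && buffered.head._1._1 == p
      override def next(): (K, V) = {
        val ((_, key), value) = buffered.next()
        (key, value)
      }
    }
    (p, partitionIterator)
  }
}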
- spillToPartitionFiles(if (aggregator.isDefined) map else buffer) - partitionWriters.foreach(_.commitAndClose()) - val out = new FileOutputStream(outputFile, true) - val writeStartTime = System.nanoTime - util.Utils.tryWithSafeFinally { - for (i <- 0 until numPartitions) { - val in = new FileInputStream(partitionWriters(i).fileSegment().file) - util.Utils.tryWithSafeFinally { - lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - } { - in.close() - } - } - } { - out.close() - context.taskMetrics.shuffleWriteMetrics.foreach( - _.incShuffleWriteTime(System.nanoTime - writeStartTime)) - } - } else if (spills.isEmpty && partitionWriters == null) { + if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer val it = collection.destructiveSortedWritablePartitionedIterator(comparator) @@ -761,7 +669,7 @@ private[spark] class ExternalSorter[K, V, C]( lengths(partitionId) = segment.length } } else { - // Not bypassing merge-sort; get an iterator by partition and just write everything directly. + // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, @@ -778,41 +686,15 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m => - if (curWriteMetrics != null) { - m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten) - m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime) - m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten) - } - } lengths } - /** - * Read a partition file back as an iterator (used in our iterator method) - */ - private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { - if (writer.isOpen) { - writer.commitAndClose() - } - new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get) - } - def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() - if (partitionWriters != null) { - partitionWriters.foreach { w => - w.revertPartialWritesAndClose() - diskBlockManager.getFile(w.blockId).delete() - } - partitionWriters = null - } } - def diskBytesSpilled: Long = _diskBytesSpilled - /** * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, * group together the pairs for each partition into a sub-iterator. @@ -826,14 +708,6 @@ private[spark] class ExternalSorter[K, V, C]( (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) } - private def comparator: Option[Comparator[K]] = { - if (ordering.isDefined || aggregator.isDefined) { - Some(keyComparator) - } else { - None - } - } - /** * An iterator that reads only the elements for a given partition ID from an underlying buffered * stream, assuming this partition is the next one to be read. 
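The writePartitionedFile path above always produces a single output file per map task and returns the byte length of every partition's segment; the map output tracker uses those lengths to tell reducers where their partition starts. The sketch below captures just that contract, with each partition's contents reduced to a pre-serialized byte array; writePartitioned and its parameters are illustrative names, not Spark APIs.

import java.io.{File, FileOutputStream}

// Append each partition's bytes to one output file and record the segment lengths.
// A reader can locate partition i by summing the lengths of partitions 0 until i.
def writePartitioned(partitions: Seq[Array[Byte]], outputFile: File): Array[Long] = {
  val lengths = new Array[Long](partitions.length)
  val out = new FileOutputStream(outputFile, true)  // append, as the patch does
  try {
    partitions.zipWithIndex.foreach { case (bytes, i) =>
      out.write(bytes)
      lengths(i) = bytes.length.toLong
    }
  } finally {
    out.close()
  }
  lengths
}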
Used to make it easier to return diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 1501111a06655..64e7102e3654c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -20,6 +20,8 @@ package org.apache.spark.util.collection import scala.reflect._ import com.google.common.hash.Hashing +import org.apache.spark.annotation.Private + /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never * removed. @@ -37,7 +39,7 @@ import com.google.common.hash.Hashing * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ -private[spark] +@Private class OpenHashSet[@specialized(Long, Int) T: ClassTag]( initialCapacity: Int, loadFactor: Double) @@ -110,6 +112,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( rehashIfNeeded(k, grow, move) } + def union(other: OpenHashSet[T]): OpenHashSet[T] = { + val iterator = other.iterator + while (iterator.hasNext) { + add(iterator.next()) + } + this + } + /** * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. * The caller is responsible for calling rehashIfNeeded. diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala index e2e2f1faae9d1..d0d25b43d0477 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala @@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V] destructiveSortedIterator(comparator) } - def writablePartitionedIterator(): WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(super.iterator) - } - def insert(partition: Int, key: K, value: V): Unit = { update((partition, key), value) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index e8332e1a87eac..5a6e9a9580e9b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) iterator } - override def writablePartitionedIterator(): WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(iterator) - } - private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] { var pos = 0 diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index ac9ea6393628f..862408b7a4d21 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -41,6 +41,13 @@ import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ * * Currently, only sorting by partition is supported. * + * Each record is laid out inside the metaBuffer as follows.
keyStart, a long, is split across + * two integers: + * + * +-------------+------------+------------+-------------+ + * | keyStart | keyValLen | partitionId | + * +-------------+------------+------------+-------------+ + * * @param metaInitialRecords The initial number of entries in the metadata buffer. * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. * @param serializerInstance the serializer used for serializing inserted records. @@ -68,19 +75,15 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( } val keyStart = kvBuffer.size - if (keyStart < 0) { - throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes") - } kvSerializationStream.writeKey[Any](key) - kvSerializationStream.flush() - val valueStart = kvBuffer.size kvSerializationStream.writeValue[Any](value) kvSerializationStream.flush() - val valueEnd = kvBuffer.size + val keyValLen = (kvBuffer.size - keyStart).toInt - metaBuffer.put(keyStart) - metaBuffer.put(valueStart) - metaBuffer.put(valueEnd) + // keyStart, a long, gets split across two ints + metaBuffer.put(keyStart.toInt) + metaBuffer.put((keyStart >> 32).toInt) + metaBuffer.put(keyValLen) metaBuffer.put(partition) } @@ -114,24 +117,20 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( } } - override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity + override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { sort(keyComparator) - writablePartitionedIterator - } - - override def writablePartitionedIterator(): WritablePartitionedIterator = { new WritablePartitionedIterator { // current position in the meta buffer in ints var pos = 0 def writeNext(writer: BlockObjectWriter): Unit = { - val keyStart = metaBuffer.get(pos + KEY_START) - val valueEnd = metaBuffer.get(pos + VAL_END) + val keyStart = getKeyStartPos(metaBuffer, pos) + val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, valueEnd - keyStart) + kvBuffer.read(keyStart, writer, keyValLen) writer.recordWritten() } def nextPartition(): Int = metaBuffer.get(pos + PARTITION) @@ -163,9 +162,11 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) extends InputStream { + import PartitionedSerializedPairBuffer._ + private var metaBufferPos = 0 private var kvBufferPos = - if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0 + if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) @@ -173,13 +174,14 @@ private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: Chained if (metaBufferPos >= metaBuffer.position) { return -1 } - val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos + val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - + (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt val toRead = math.min(bytesRemainingInRecord, len) kvBuffer.read(kvBufferPos, bytes, offs, toRead) if (toRead == bytesRemainingInRecord) { metaBufferPos += RECORD_SIZE if (metaBufferPos < metaBuffer.position) { - kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START) + kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) } } else { kvBufferPos += toRead @@ 
-246,9 +248,14 @@ private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuf } private[spark] object PartitionedSerializedPairBuffer { - val KEY_START = 0 - val VAL_START = 1 - val VAL_END = 2 + val KEY_START = 0 // keyStart, a long, gets split across two ints + val KEY_VAL_LEN = 2 val PARTITION = 3 - val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata + val RECORD_SIZE = PARTITION + 1 // num ints of metadata + + def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { + val lower32 = metaBuffer.get(metaBufferPos + KEY_START) + val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) + (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala index 4f0bf8384afc9..9a7a5a4e74868 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala @@ -90,9 +90,9 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, override def swap(data: Array[T], pos0: Int, pos1: Int) { val tmpKey = data(2 * pos0) val tmpVal = data(2 * pos0 + 1) - data(2 * pos0) = data(2 * pos1) + data(2 * pos0) = data(2 * pos1) data(2 * pos0 + 1) = data(2 * pos1 + 1) - data(2 * pos1) = tmpKey + data(2 * pos1) = tmpKey data(2 * pos1 + 1) = tmpVal } diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index f26d1618c9200..7bc59898658e4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { */ def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator)) - } + val it = partitionedDestructiveSortedIterator(keyComparator) + new WritablePartitionedIterator { + private[this] var cur = if (it.hasNext) it.next() else null - /** - * Iterate through the data and write out the elements instead of returning them. 
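The metadata layout change above stores keyStart, a Long offset into the chained key/value buffer, as two consecutive ints (low word first, then high word), and getKeyStartPos reassembles it. Here is a small self-contained illustration of that packing scheme; putKeyStart is a hypothetical helper introduced only for the example, while the reassembly mirrors the bit operations in getKeyStartPos.

import java.nio.IntBuffer

// Store a Long as two ints: lower 32 bits first, then upper 32 bits.
def putKeyStart(meta: IntBuffer, keyStart: Long): Unit = {
  meta.put(keyStart.toInt)           // lower 32 bits
  meta.put((keyStart >> 32).toInt)   // upper 32 bits
}

// Reassemble the Long from the two ints, masking the low word so its sign bit
// does not leak into the upper half.
def getKeyStart(meta: IntBuffer, pos: Int): Long = {
  val lower32 = meta.get(pos)
  val upper32 = meta.get(pos + 1)
  (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
}

val meta = IntBuffer.allocate(4)
val offset = (1L << 33) + 7          // deliberately larger than Int.MaxValue
putKeyStart(meta, offset)
assert(getKeyStart(meta, 0) == offset)

Splitting the Long this way is what removes the old "Can't grow buffer beyond 2^31 bytes" limit while keeping the metadata a flat IntBuffer.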
- */ - def writablePartitionedIterator(): WritablePartitionedIterator + def writeNext(writer: BlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + } } private[spark] object WritablePartitionedPairCollection { @@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator { def nextPartition(): Int } - -private[spark] object WritablePartitionedIterator { - def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = { - new WritablePartitionedIterator { - var cur = if (it.hasNext) it.next() else null - - def writeNext(writer: BlockObjectWriter): Unit = { - writer.write(cur._1._2, cur._2) - cur = if (it.hasNext) it.next() else null - } - - def hasNext(): Boolean = cur != null - - def nextPartition(): Int = cur._1._1 - } - } -} diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 9e29bf9d61f17..effe6fa2adcfa 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -196,7 +196,7 @@ private[spark] object StratifiedSamplingUtils extends Logging { * * The sampling function has a unique seed per partition. */ - def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], + def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], fractions: Map[K, Double], exact: Boolean, seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c2089b0e56a1f..dfd86d3e51e7d 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -212,6 +212,8 @@ public int getPartition(Object key) { JavaPairRDD repartitioned = rdd.repartitionAndSortWithinPartitions(partitioner); + Assert.assertTrue(repartitioned.partitioner().isPresent()); + Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), new Tuple2(0, 8), new Tuple2(2, 6))); diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index ce4fe80b66aa5..d575bf2f284b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -7,6 +7,22 @@ "sparkUser" : "irashid", "completed" : true } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "sparkUser" : "irashid", + "completed" : true + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "sparkUser" : "irashid", + "completed" : true + } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 
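The JavaAPISuite change above asserts that the RDD returned by repartitionAndSortWithinPartitions exposes the partitioner it was given. The equivalent check in Scala, as a hedged sketch using a throwaway local SparkContext (master, app name, and the sample data are arbitrary):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("partitioner-check"))
try {
  val partitioner = new HashPartitioner(2)
  val repartitioned = sc.parallelize(Seq((3, 1), (0, 5), (2, 6), (0, 8)))
    .repartitionAndSortWithinPartitions(partitioner)
  // The resulting RDD advertises its partitioner, so later co-partitioned
  // operations can avoid an extra shuffle.
  assert(repartitioned.partitioner == Some(partitioner))
} finally {
  sc.stop()
}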
ce4fe80b66aa5..d575bf2f284b9 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -7,6 +7,22 @@ "sparkUser" : "irashid", "completed" : true } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "sparkUser" : "irashid", + "completed" : true + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "sparkUser" : "irashid", + "completed" : true + } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index dca86fe5f7e6a..15c2de8ef99ea 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -7,6 +7,22 @@ "sparkUser" : "irashid", "completed" : true } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "sparkUser" : "irashid", + "completed" : true + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "sparkUser" : "irashid", + "completed" : true + } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", @@ -24,12 +40,14 @@ "completed" : true } ] }, { - "id" : "local-1425081759269", - "name" : "Spark shell", - "attempts" : [ { - "startTime" : "2015-02-28T00:02:38.277GMT", - "endTime" : "2015-02-28T00:02:46.912GMT", - "sparkUser" : "irashid", - "completed" : true - } ] + "id": "local-1425081759269", + "name": "Spark shell", + "attempts": [ + { + "startTime": "2015-02-28T00:02:38.277GMT", + "endTime": "2015-02-28T00:02:46.912GMT", + "sparkUser": "irashid", + "completed": true + } + ] } ] \ No newline at end of file diff --git a/core/src/test/resources/spark-events/local-1430917381535_1 b/core/src/test/resources/spark-events/local-1430917381535_1 new file mode 100644 index 0000000000000..d5a1303344825 --- /dev/null +++ b/core/src/test/resources/spark-events/local-1430917381535_1 @@ -0,0 +1,5 @@ +{"Event":"SparkListenerLogStart","Spark Version":"1.4.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"localhost","Port":61103},"Maximum Memory":278019440,"Timestamp":1430917380880} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","Java Version":"1.8.0_25 (Oracle Corporation)","Scala Version":"version 2.10.4"},"Spark Properties":{"spark.driver.host":"192.168.1.102","spark.eventLog.enabled":"true","spark.driver.port":"61101","spark.repl.class.uri":"http://192.168.1.102:61100","spark.jars":"","spark.app.name":"Spark 
shell","spark.scheduler.mode":"FIFO","spark.executor.id":"driver","spark.master":"local[*]","spark.eventLog.dir":"/Users/irashid/github/kraps/core/src/test/resources/spark-events","spark.fileserver.uri":"http://192.168.1.102:61102","spark.tachyonStore.folderName":"spark-aaaf41b3-d1dd-447f-8951-acf51490758b","spark.app.id":"local-1430917381534"},"System Properties":{"java.io.tmpdir":"/var/folders/36/m29jw1z95qv4ywb1c4n0rz000000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/irashid","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib","user.dir":"/Users/irashid/github/spark","java.library.path":"/Users/irashid/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.25-b02","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_25-b17","java.vm.info":"mixed mode","java.ext.dirs":"/Users/irashid/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.9.5","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"irashid","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --conf spark.eventLog.enabled=true --conf 
spark.eventLog.dir=/Users/irashid/github/kraps/core/src/test/resources/spark-events --class org.apache.spark.repl.Main spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","java.version":"1.8.0_25","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/etc/hadoop":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/Users/irashid/github/spark/conf/":"System Classpath","/Users/irashid/github/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.5.0.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-core-3.2.10.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"local-1430917381535","Timestamp":1430917380880,"User":"irashid","App Attempt ID":"1"} +{"Event":"SparkListenerApplicationEnd","Timestamp":1430917380890} \ No newline at end of file diff --git a/core/src/test/resources/spark-events/local-1430917381535_2 b/core/src/test/resources/spark-events/local-1430917381535_2 new file mode 100644 index 0000000000000..abb637a22e1e3 --- /dev/null +++ b/core/src/test/resources/spark-events/local-1430917381535_2 @@ -0,0 +1,5 @@ +{"Event":"SparkListenerLogStart","Spark Version":"1.4.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"localhost","Port":61103},"Maximum Memory":278019440,"Timestamp":1430917380893} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","Java Version":"1.8.0_25 (Oracle Corporation)","Scala Version":"version 2.10.4"},"Spark Properties":{"spark.driver.host":"192.168.1.102","spark.eventLog.enabled":"true","spark.driver.port":"61101","spark.repl.class.uri":"http://192.168.1.102:61100","spark.jars":"","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.executor.id":"driver","spark.master":"local[*]","spark.eventLog.dir":"/Users/irashid/github/kraps/core/src/test/resources/spark-events","spark.fileserver.uri":"http://192.168.1.102:61102","spark.tachyonStore.folderName":"spark-aaaf41b3-d1dd-447f-8951-acf51490758b","spark.app.id":"local-1430917381534"},"System Properties":{"java.io.tmpdir":"/var/folders/36/m29jw1z95qv4ywb1c4n0rz000000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/irashid","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib","user.dir":"/Users/irashid/github/spark","java.library.path":"/Users/irashid/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.25-b02","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_25-b17","java.vm.info":"mixed 
mode","java.ext.dirs":"/Users/irashid/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.9.5","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"irashid","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --conf spark.eventLog.enabled=true --conf spark.eventLog.dir=/Users/irashid/github/kraps/core/src/test/resources/spark-events --class org.apache.spark.repl.Main spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","java.version":"1.8.0_25","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/etc/hadoop":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/Users/irashid/github/spark/conf/":"System Classpath","/Users/irashid/github/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.5.0.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-core-3.2.10.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"local-1430917381535","Timestamp":1430917380893,"User":"irashid","App Attempt ID":"2"} +{"Event":"SparkListenerApplicationEnd","Timestamp":1430917380950} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 75399461f2a5f..e942d6579b2fd 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark import scala.collection.mutable import scala.ref.WeakReference -import 
org.scalatest.FunSuite import org.scalatest.Matchers -class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { +class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = @@ -103,7 +102,7 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { sc = new SparkContext("local[" + nThreads + "]", "test") val setAcc = sc.accumulableCollection(mutable.HashSet[Int]()) val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]()) - val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]()) + val mapAcc = sc.accumulableCollection(mutable.HashMap[Int, String]()) val d = sc.parallelize((1 to maxI) ++ (1 to maxI)) d.foreach { x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)} diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 668ddf9f5f0a9..af81e46a657d3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar import org.apache.spark.executor.DataReadMethod @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter +class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with MockitoSugar { var blockManager: BlockManager = _ diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index e1faddeabec79..d1761a48babbc 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,13 +21,11 @@ import java.io.File import scala.reflect.ClassTag -import org.scalatest.FunSuite - import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils -class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { var checkpointDir: File = _ val partitioner = new HashPartitioner(2) @@ -218,10 +216,10 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val pairRDD = generateFatPairRDD() pairRDD.checkpoint() val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD)) - val partitionBeforeCheckpoint = serializeDeserialize( + val partitionBeforeCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) pairRDD.count() - val partitionAfterCheckpoint = serializeDeserialize( + val partitionAfterCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) assert( partitionBeforeCheckpoint.parents.head.getClass != diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 0922a2c3599cc..501fe186bfd7c 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -23,7 +23,7 @@ import 
scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials import scala.util.Random -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.{PatienceConfiguration, Eventually} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -44,7 +44,7 @@ import org.apache.spark.storage.ShuffleIndexBlockId * config options, in particular, a different shuffle manager class */ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager]) - extends FunSuite with BeforeAndAfter with LocalSparkContext + extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { implicit val defaultTimeout = timeout(10000 millis) val conf = new SparkConf() @@ -158,7 +158,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { rdd.count() // Test that GC does not cause RDD cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) @@ -195,7 +195,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { var broadcast = newBroadcast() // Test that GC does not cause broadcast cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) @@ -267,7 +267,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { val shuffleIds = 0 until sc.newShuffleId val broadcastIds = broadcastBuffer.map(_.id) - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 96a9c207ad022..9c191ed52206d 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} @@ -28,7 +27,7 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index c42dfbc82ada4..b2262033ca238 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark import java.io.File -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils -class DriverSuite extends FunSuite with Timeouts { +class DriverSuite extends SparkFunSuite with Timeouts { ignore("driver should exit after finishing without cleanup (SPARK-530)") { 
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 84f787ee3715d..803e1831bb269 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -28,7 +28,11 @@ import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { +class ExecutorAllocationManagerSuite + extends SparkFunSuite + with LocalSparkContext + with BeforeAndAfter { + import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ @@ -86,7 +90,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("add executors") { - sc = createSparkContext(1, 10) + sc = createSparkContext(1, 10, 1) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) @@ -131,7 +135,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("add executors capped by num pending tasks") { - sc = createSparkContext(0, 10) + sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) @@ -182,7 +186,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("cancel pending executors when no longer needed") { - sc = createSparkContext(0, 10) + sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) @@ -209,7 +213,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("remove executors") { - sc = createSparkContext(5, 10) + sc = createSparkContext(5, 10, 5) val manager = sc.executorAllocationManager.get (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) } @@ -259,7 +263,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test ("interleaving add and remove") { - sc = createSparkContext(5, 10) + sc = createSparkContext(5, 10, 5) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) @@ -327,7 +331,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("starting/canceling add timer") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val clock = new ManualClock(8888L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -359,7 +363,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("starting/canceling remove timers") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val clock = new ManualClock(14444L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -406,7 +410,7 @@ 
class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("mock polling loop with no events") { - sc = createSparkContext(0, 20) + sc = createSparkContext(0, 20, 0) val manager = sc.executorAllocationManager.get val clock = new ManualClock(2020L) manager.setClock(clock) @@ -432,7 +436,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("mock polling loop add behavior") { - sc = createSparkContext(0, 20) + sc = createSparkContext(0, 20, 0) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -482,7 +486,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("mock polling loop remove behavior") { - sc = createSparkContext(1, 20) + sc = createSparkContext(1, 20, 1) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -543,7 +547,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("listeners trigger add executors correctly") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(addTime(manager) === NOT_SET) @@ -573,7 +577,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("listeners trigger remove executors correctly") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(removeTimes(manager).isEmpty) @@ -604,7 +608,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("listeners trigger add and remove executor callbacks correctly") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) @@ -637,7 +641,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("SPARK-4951: call onTaskStart before onBlockManagerAdded") { - sc = createSparkContext(2, 10) + sc = createSparkContext(2, 10, 2) val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) @@ -673,7 +677,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("avoid ramp up when target < running executors") { - sc = createSparkContext(0, 100000) + sc = createSparkContext(0, 100000, 0) val manager = sc.executorAllocationManager.get val stage1 = createStageInfo(0, 1000) sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) @@ -697,13 +701,67 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit assert(numExecutorsTarget(manager) === 16) } - private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { + test("avoid ramp down initial executors until first job is submitted") { + sc = createSparkContext(2, 5, 3) + val manager = sc.executorAllocationManager.get + val clock = new ManualClock(10000L) + manager.setClock(clock) + + // Verify the initial number of executors + assert(numExecutorsTarget(manager) === 3) + schedule(manager) + // Verify whether the initial number of executors is kept with no pending tasks + assert(numExecutorsTarget(manager) === 3) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + clock.advance(100L) + + assert(maxNumExecutorsNeeded(manager) === 2) + schedule(manager) + + // 
Verify that current number of executors should be ramp down when first job is submitted + assert(numExecutorsTarget(manager) === 2) + } + + test("avoid ramp down initial executors until idle executor is timeout") { + sc = createSparkContext(2, 5, 3) + val manager = sc.executorAllocationManager.get + val clock = new ManualClock(10000L) + manager.setClock(clock) + + // Verify the initial number of executors + assert(numExecutorsTarget(manager) === 3) + schedule(manager) + // Verify the initial number of executors is kept when no pending tasks + assert(numExecutorsTarget(manager) === 3) + (0 until 3).foreach { i => + onExecutorAdded(manager, s"executor-$i") + } + + clock.advance(executorIdleTimeout * 1000) + + assert(maxNumExecutorsNeeded(manager) === 0) + schedule(manager) + // Verify executor is timeout but numExecutorsTarget is not recalculated + assert(numExecutorsTarget(manager) === 3) + + // Schedule again to recalculate the numExecutorsTarget after executor is timeout + schedule(manager) + // Verify that current number of executors should be ramp down when executor is timeout + assert(numExecutorsTarget(manager) === 2) + } + + private def createSparkContext( + minExecutors: Int = 1, + maxExecutors: Int = 5, + initialExecutors: Int = 1): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test-executor-allocation-manager") .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) .set("spark.dynamicAllocation.schedulerBacklogTimeout", s"${schedulerBacklogTimeout.toString}s") .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", @@ -787,6 +845,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _schedule() } + private def maxNumExecutorsNeeded(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _maxNumExecutorsNeeded() + } + private def addExecutors(manager: ExecutorAllocationManager): Int = { val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded() manager invokePrivate _addExecutors(maxNumExecutorsNeeded) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index bac6fdbcdc976..140012226fdbb 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -55,6 +55,14 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) + // In a slow machine, one slave may register hundreds of milliseconds ahead of the other one. + // If we don't wait for all slaves, it's possible that only one executor runs all jobs. Then + // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch + // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. + // In this case, we won't receive FetchFailed. And it will make this test fail. 
+ // Therefore, we should wait until all slaves are up + sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) rdd.count() diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 1212d0b43207d..a8c8c6f73fb5a 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark -import org.scalatest.FunSuite - import org.apache.spark.util.NonSerializable import java.io.NotSerializableException @@ -38,7 +36,7 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with LocalSparkContext { +class FailureSuite extends SparkFunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. @@ -57,7 +55,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { FailureSuiteState.synchronized { assert(FailureSuiteState.tasksRun === 4) } - assert(results.toList === List(1,4,9)) + assert(results.toList === List(1, 4, 9)) FailureSuiteState.clear() } @@ -119,7 +117,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).map(x => a).count() } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException") || + assert(thrown.getMessage.contains("NotSerializableException") || thrown.getCause.getClass === classOf[NotSerializableException]) // Non-serializable closure in an earlier stage @@ -127,7 +125,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() } assert(thrown1.getClass === classOf[SparkException]) - assert(thrown1.getMessage.contains("NotSerializableException") || + assert(thrown1.getMessage.contains("NotSerializableException") || thrown1.getCause.getClass === classOf[NotSerializableException]) // Non-serializable closure in foreach function @@ -135,7 +133,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).foreach(x => println(a)) } assert(thrown2.getClass === classOf[SparkException]) - assert(thrown2.getMessage.contains("NotSerializableException") || + assert(thrown2.getMessage.contains("NotSerializableException") || thrown2.getCause.getClass === classOf[NotSerializableException]) FailureSuiteState.clear() diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index c0439f934813e..6e65b0a8f6c76 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -24,13 +24,12 @@ import javax.net.ssl.SSLException import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.RandomUtils -import org.scalatest.FunSuite import org.apache.spark.util.Utils import SSLSampleConfigs._ -class FileServerSuite extends FunSuite with LocalSparkContext { +class FileServerSuite extends SparkFunSuite with LocalSparkContext { @transient var tmpDir: File = _ @transient var tmpFile: File = _ @@ -81,7 +80,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test("Distributing files locally") { sc = new SparkContext("local[4]", "test", newConf) sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), 
(2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -89,7 +88,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test("Distributing files locally security On") { @@ -100,7 +99,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { sc.addFile(tmpFile.toString) assert(sc.env.securityManager.isAuthenticationEnabled() === true) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -108,14 +107,14 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test("Distributing files locally using URL as input") { // addFile("file:///....") sc = new SparkContext("local[4]", "test", newConf) sc.addFile(new File(tmpFile.toString).toURI.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -123,7 +122,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test ("Dynamically adding JARS locally") { @@ -140,7 +139,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test("Distributing files on a standalone cluster") { sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") val in = new BufferedReader(new FileReader(path)) @@ -148,13 +147,13 @@ class FileServerSuite extends FunSuite with LocalSparkContext { in.close() _ * fileVal + _ * fileVal }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) + assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) } test ("Dynamically adding JARS on a standalone cluster") { sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl) - val testData = Array((1,1)) + val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { throw new SparkException("jar not added") @@ -165,7 +164,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test ("Dynamically adding JARS on a standalone cluster using local: URL") { sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) - val testData = Array((1,1)) + val testData = Array((1, 1)) 
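The FileServerSuite hunks above and below only reformat tuple literals; for readers unfamiliar with the API they exercise, here is a minimal standalone sketch of the `addFile`/`SparkFiles` pattern. The file path, app name, and object name are assumptions for illustration, not values from the suite.

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkFiles}

object AddFileExample {
  def main(args: Array[String]): Unit = {
    // Assumes a small local text file exists at /tmp/lookup.txt.
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("add-file-example"))
    sc.addFile("/tmp/lookup.txt")
    // Each task resolves its own local copy of the distributed file by name.
    val sizes = sc.parallelize(1 to 4, 4).map { _ =>
      val source = scala.io.Source.fromFile(SparkFiles.get("lookup.txt"))
      try source.mkString.length finally source.close()
    }.collect()
    sizes.foreach(println)
    sc.stop()
  }
}
```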
sc.parallelize(testData).foreach { x => if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { throw new SparkException("jar not added") diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index c8f08eed47c76..1d8fade90f398 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.scalatest.FunSuite import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils -class FileSuite extends FunSuite with LocalSparkContext { +class FileSuite extends SparkFunSuite with LocalSparkContext { var tempDir: File = _ override def beforeEach() { @@ -334,7 +333,7 @@ class FileSuite extends FunSuite with LocalSparkContext { } val copyRdd = mappedRdd.flatMap { curData: (String, PortableDataStream) => - for(i <- 1 to numOfCopies) yield (i, curData._2) + for (i <- 1 to numOfCopies) yield (i, curData._2) } val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect() diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala index f5cdb01ec9504..1102aea96b548 100644 --- a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala +++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala @@ -20,10 +20,14 @@ package org.apache.spark import scala.concurrent.Await import scala.concurrent.duration.Duration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} -class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext { +class FutureActionSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with LocalSparkContext { before { sc = new SparkContext("local", "FutureActionSuite") diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index b789912e9ebef..911b3bddd1836 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -22,7 +22,6 @@ import scala.language.postfixOps import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId -import org.scalatest.FunSuite import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ @@ -31,7 +30,7 @@ import org.apache.spark.scheduler.TaskScheduler import org.apache.spark.util.RpcUtils import org.scalatest.concurrent.Eventually._ -class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { +class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext { test("HeartbeatReceiver") { sc = spy(new SparkContext("local[2]", "test")) diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index 51348c039b5c9..4399f25626472 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ 
-17,11 +17,9 @@ package org.apache.spark -import org.scalatest.FunSuite - import org.apache.spark.rdd.RDD -class ImplicitOrderingSuite extends FunSuite with LocalSparkContext { +class ImplicitOrderingSuite extends SparkFunSuite with LocalSparkContext { // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should. test("basic inference of Orderings"){ sc = new SparkContext("local", "test") @@ -29,11 +27,11 @@ class ImplicitOrderingSuite extends FunSuite with LocalSparkContext { // These RDD methods are in the companion object so that the unserializable ScalaTest Engine // won't be reachable from the closure object - + // Infer orderings after basic maps to particular types val basicMapExpectations = ImplicitOrderingSuite.basicMapExpectations(rdd) basicMapExpectations.map({case (met, explain) => assert(met, explain)}) - + // Infer orderings for other RDD methods val otherRDDMethodExpectations = ImplicitOrderingSuite.otherRDDMethodExpectations(rdd) otherRDDMethodExpectations.map({case (met, explain) => assert(met, explain)}) @@ -44,36 +42,36 @@ private object ImplicitOrderingSuite { class NonOrderedClass {} class ComparableClass extends Comparable[ComparableClass] { - override def compareTo(o: ComparableClass): Int = ??? + override def compareTo(o: ComparableClass): Int = throw new UnsupportedOperationException } class OrderedClass extends Ordered[OrderedClass] { - override def compare(o: OrderedClass): Int = ??? + override def compare(o: OrderedClass): Int = throw new UnsupportedOperationException } - + def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { - List((rdd.map(x => (x, x)).keyOrdering.isDefined, + List((rdd.map(x => (x, x)).keyOrdering.isDefined, "rdd.map(x => (x, x)).keyOrdering.isDefined"), - (rdd.map(x => (1, x)).keyOrdering.isDefined, + (rdd.map(x => (1, x)).keyOrdering.isDefined, "rdd.map(x => (1, x)).keyOrdering.isDefined"), - (rdd.map(x => (x.toString, x)).keyOrdering.isDefined, + (rdd.map(x => (x.toString, x)).keyOrdering.isDefined, "rdd.map(x => (x.toString, x)).keyOrdering.isDefined"), - (rdd.map(x => (null, x)).keyOrdering.isDefined, + (rdd.map(x => (null, x)).keyOrdering.isDefined, "rdd.map(x => (null, x)).keyOrdering.isDefined"), - (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty, + (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty, "rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"), - (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined, + (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined, "rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"), - (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined, + (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined, "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } - + def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { - List((rdd.groupBy(x => x).keyOrdering.isDefined, + List((rdd.groupBy(x => x).keyOrdering.isDefined, "rdd.groupBy(x => x).keyOrdering.isDefined"), - (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, + (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, "rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"), - (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined, + (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined, "rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"), (rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined, "rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"), diff --git 
a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index ae17fc60e4a43..340a9e327107e 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -24,7 +24,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.future -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} @@ -34,7 +34,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers * in both FIFO and fair scheduling modes. */ -class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter +class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6ed057a7cab97..1fab69678d040 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} -import org.scalatest.FunSuite import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 47e3bf6e1ac41..3316f561a4949 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer import scala.math.abs -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.PrivateMethodTester import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter -class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester { +class PartitioningSuite extends SparkFunSuite with SharedSparkContext with PrivateMethodTester { test("HashPartitioner equality") { val p2 = new HashPartitioner(2) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 93f46ef11c0e2..376481ba541fa 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -21,9 +21,9 @@ import java.io.File import com.google.common.io.Files import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { +class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as 
spark conf ") { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 308b9ea17708d..1a099da2c6c8e 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -34,7 +34,7 @@ object SSLSampleConfigs { conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") conf.set("spark.ssl.protocol", "TLSv1") conf } @@ -48,7 +48,7 @@ object SSLSampleConfigs { conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") conf.set("spark.ssl.protocol", "TLSv1") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 62cb7649c0284..e9b64aa82a17a 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,11 +19,9 @@ package org.apache.spark import java.io.File -import org.scalatest.FunSuite - import org.apache.spark.util.Utils -class SecurityManagerSuite extends FunSuite { +class SecurityManagerSuite extends SparkFunSuite { test("set security with conf") { val conf = new SparkConf @@ -147,7 +145,7 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -158,7 +156,7 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) } test("ssl off setup") { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d7180516029d5..c3c2b1ffc1efa 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair -abstract class ShuffleSuite extends FunSuite with 
Matchers with LocalSparkContext { +abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { val conf = new SparkConf(loadDefaults = false) @@ -282,6 +282,39 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // This count should retry the execution of the previous stage and rerun shuffle. rdd.count() } + + test("metrics for shuffle without aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val numRecords = 10000 + + val metrics = ShuffleSuite.runAndReturnMetrics(sc) { + sc.parallelize(1 to numRecords, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + + assert(metrics.recordsRead === numRecords) + assert(metrics.recordsWritten === numRecords) + assert(metrics.bytesWritten === metrics.bytesRead) + assert(metrics.bytesWritten > 0) + } + + test("metrics for shuffle with aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val numRecords = 10000 + + val metrics = ShuffleSuite.runAndReturnMetrics(sc) { + sc.parallelize(1 to numRecords, 4) + .flatMap(key => Array.fill(100)((key, 1))) + .countByKey() + } + + assert(metrics.recordsRead === numRecords) + assert(metrics.recordsWritten === numRecords) + assert(metrics.bytesWritten === metrics.bytesRead) + assert(metrics.bytesWritten > 0) + } } object ShuffleSuite { @@ -295,4 +328,35 @@ object ShuffleSuite { value - o.value } } + + case class AggregatedShuffleMetrics( + recordsWritten: Long, + recordsRead: Long, + bytesWritten: Long, + bytesRead: Long) + + def runAndReturnMetrics(sc: SparkContext)(job: => Unit): AggregatedShuffleMetrics = { + @volatile var recordsWritten: Long = 0 + @volatile var recordsRead: Long = 0 + @volatile var bytesWritten: Long = 0 + @volatile var bytesRead: Long = 0 + val listener = new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m => + recordsWritten += m.shuffleRecordsWritten + bytesWritten += m.shuffleBytesWritten + } + taskEnd.taskMetrics.shuffleReadMetrics.foreach { m => + recordsRead += m.recordsRead + bytesRead += m.totalBytesRead + } + } + } + sc.addSparkListener(listener) + + job + + sc.listenerBus.waitUntilEmpty(500) + AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead) + } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index fafa4ed606b08..9fbaeb33f97cd 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -23,29 +23,28 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Try, Random} -import org.scalatest.FunSuite import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo -class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { +class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("Test byteString conversion") { val conf = new SparkConf() // Simply exercise the API, we don't need a complete conversion test since that's handled in // UtilsSuite.scala - assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) - assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) - assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) - 
assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + assert(conf.getSizeAsBytes("fake", "1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake", "1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake", "1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake", "1k") === ByteUnit.KiB.toGiB(1)) } test("Test timeString conversion") { val conf = new SparkConf() // Simply exercise the API, we don't need a complete conversion test since that's handled in // UtilsSuite.scala - assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) - assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + assert(conf.getTimeAsMs("fake", "1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake", "1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) } test("loading from system properties") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index e6ab538d77bcc..2bdbd70c638a5 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions import org.apache.spark.storage.StorageLevel -class SparkContextInfoSuite extends FunSuite with LocalSparkContext { +class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 9343f4fff89da..f89e3d0a49920 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.PrivateMethodTester import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite - extends FunSuite with LocalSparkContext with PrivateMethodTester with Logging { + extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { def createTaskScheduler(master: String): TaskSchedulerImpl = createTaskScheduler(master, new SparkConf()) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 9049db7755358..6838b35ab4cc8 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -23,8 +23,6 @@ import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files -import org.scalatest.FunSuite - import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} @@ -33,7 
+31,7 @@ import org.apache.spark.util.Utils import scala.concurrent.Await import scala.concurrent.duration.Duration -class SparkContextSuite extends FunSuite with LocalSparkContext { +class SparkContextSuite extends SparkFunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 @@ -73,22 +71,22 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { var sc2: SparkContext = null SparkContext.clearActiveContext() val conf = new SparkConf().setAppName("test").setMaster("local") - + sc = SparkContext.getOrCreate(conf) - + assert(sc.getConf.get("spark.app.name").equals("test")) sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("test2").setMaster("local")) assert(sc2.getConf.get("spark.app.name").equals("test")) assert(sc === sc2) assert(sc eq sc2) - + // Try creating second context to confirm that it's still possible, if desired sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") .set("spark.driver.allowMultipleContexts", "true")) - + sc2.stop() } - + test("BytesWritable implicit conversion is correct") { // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() @@ -222,8 +220,8 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val dir1 = Utils.createTempDir() val dir2 = Utils.createTempDir() - val dirpath1=dir1.getAbsolutePath - val dirpath2=dir2.getAbsolutePath + val dirpath1 = dir1.getAbsolutePath + val dirpath2 = dir2.getAbsolutePath // file1 and file2 are placed inside dir1, they are also used for // textFile, hadoopFile, and newAPIHadoopFile @@ -235,11 +233,11 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val file4 = new File(dir2, "part-00001") val file5 = new File(dir2, "part-00002") - val filepath1=file1.getAbsolutePath - val filepath2=file2.getAbsolutePath - val filepath3=file3.getAbsolutePath - val filepath4=file4.getAbsolutePath - val filepath5=file5.getAbsolutePath + val filepath1 = file1.getAbsolutePath + val filepath2 = file2.getAbsolutePath + val filepath3 = file3.getAbsolutePath + val filepath4 = file4.getAbsolutePath + val filepath5 = file5.getAbsolutePath try { diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala new file mode 100644 index 0000000000000..9be9db01c7de9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +// scalastyle:off +import org.scalatest.{FunSuite, Outcome} + +/** + * Base abstract class for all unit tests in Spark for handling common functionality. 
+ */ +private[spark] abstract class SparkFunSuite extends FunSuite with Logging { +// scalastyle:on + + /** + * Log the suite name and the test name before and after each test. + * + * Subclasses should never override this method. If they wish to run + * custom code before and after each test, they should mix in the + * {{org.scalatest.BeforeAndAfter}} trait instead. + */ + final protected override def withFixture(test: NoArgTest): Outcome = { + val testName = test.text + val suiteName = this.getClass.getName + val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") + try { + logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") + test() + } finally { + logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 084eb237d70d1..46516e8d25298 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -21,12 +21,12 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark.JobExecutionStatus._ -class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { +class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkContext { test("basic status API usage") { sc = new SparkContext("local", "test", new SparkConf(false)) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 10917c866cc7d..6580139df6c60 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.scheduler._ -import org.scalatest.FunSuite /** * Holds state shared across task threads in some ThreadingSuite tests. 
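As a usage note on the new base class introduced above (this example is not part of the patch): a suite only has to extend `SparkFunSuite` instead of `FunSuite`, and `withFixture` then brackets every test with the logged banner. A minimal hypothetical suite:

```scala
package org.apache.spark

// Hypothetical suite, shown only to illustrate the new base class. Because SparkFunSuite is
// private[spark], subclasses live under the org.apache.spark package, as all of the suites
// touched by this patch already do. Running it logs lines such as:
//   ===== TEST OUTPUT FOR o.a.s.ExampleSuite: 'addition works' =====
class ExampleSuite extends SparkFunSuite {
  test("addition works") {
    assert(1 + 1 === 2)
  }
}
```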
@@ -37,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends FunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index 42ff059e018a3..f7a13ab3996d8 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.{Millis, Span} -class UnpersistSuite extends FunSuite with LocalSparkContext { +class UnpersistSuite extends SparkFunSuite with LocalSparkContext { test("unpersist RDD") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala index 8959a843dbd7d..135c56bf5bc9d 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -21,15 +21,15 @@ import scala.io.Source import java.io.{PrintWriter, File} -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers -import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils // This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize // a PythonBroadcast: -class PythonBroadcastSuite extends FunSuite with Matchers with SharedSparkContext { +class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext { test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") { val tempDir = Utils.createTempDir() val broadcastedString = "Hello, world!" 
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index c63d834f9048b..41f2a5c972b6b 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class PythonRDDSuite extends FunSuite { +class PythonRDDSuite extends SparkFunSuite { test("Writing large strings to the worker") { val input: List[String] = List("a"*100000) diff --git a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala index f8c39326145e1..267a79fa63782 100644 --- a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.api.python -import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} -import org.apache.spark.SharedSparkContext - -class SerDeUtilSuite extends FunSuite with SharedSparkContext { +class SerDeUtilSuite extends SparkFunSuite with SharedSparkContext { test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 06e5f1cf6b96f..c054c718075f8 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.broadcast -import scala.concurrent.duration._ import scala.util.Random -import org.scalatest.{Assertions, FunSuite} -import org.scalatest.concurrent.Eventually._ +import org.scalatest.Assertions -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv} +import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer @@ -45,7 +43,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends FunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") private val torrentConf = broadcastConf("TorrentBroadcastFactory") @@ -286,7 +284,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -312,13 +310,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { val _sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up - eventually(timeout(10.seconds), interval(10.milliseconds)) { - _sc.jobProgressListener.synchronized { - val numBlockManagers = _sc.jobProgressListener.blockManagerIds.size - assert(numBlockManagers == numSlaves + 1, - s"Expect ${numSlaves + 1} block managers, but was ${numBlockManagers}") - } - } + 
_sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) _sc } else { new SparkContext("local", "test", broadcastConf) diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 745f9eeee7536..6a99dbca64f4b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy -import org.scalatest.FunSuite import org.scalatest.Matchers -class ClientSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class ClientSuite extends SparkFunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) ClientArguments.isValidJarUrl("https://someHost:8080/foo.jar") should be (true) diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index e04a79284175c..08529e0ef2806 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -23,14 +23,13 @@ import java.util.Date import com.fasterxml.jackson.core.JsonParseException import org.json4s._ import org.json4s.jackson.JsonMethods -import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf} +import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} -class JsonProtocolSuite extends FunSuite with JsonTestUtils { +class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index c93d16f8a1586..ddc92814c0acf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -23,13 +23,11 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source -import org.scalatest.FunSuite - import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { +class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { /** Length of time to wait while draining listener events. 
*/ private val WAIT_TIMEOUT_MILLIS = 10000 @@ -43,7 +41,7 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) // Browse to each URL to check that it's valid @@ -73,7 +71,7 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] assert(listeners.size === 1) val listener = listeners(0) diff --git a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala index 80f2cc02516fe..473a2d7b2a258 100644 --- a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.deploy -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils -class PythonRunnerSuite extends FunSuite { +class PythonRunnerSuite extends SparkFunSuite { // Test formatting a single path to be added to the PYTHONPATH test("format path") { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index ea9227a7e9af5..46ea28d0f18f6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams -import org.scalatest.FunSuite import org.scalatest.Matchers import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -35,7 +34,12 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch // of properties that neeed to be cleared after tests. 
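The LogUrlsStandaloneSuite hunks above drop the `assert` wrapper and simply drain the listener bus before checking the collected events. A minimal sketch of that register, run, drain, then assert pattern follows; the suite and test names are hypothetical, and `listenerBus` is `private[spark]`, which is why code like this sits in the `org.apache.spark` package.

```scala
package org.apache.spark

import scala.collection.mutable

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical test showing the register -> run -> drain -> assert pattern used in these suites.
class ListenerDrainExample extends SparkFunSuite with LocalSparkContext {
  test("task-end events are observed after draining the listener bus") {
    sc = new SparkContext("local[2]", "listener-drain-example")
    val stageIds = mutable.Buffer[Int]()
    sc.addSparkListener(new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        stageIds += taskEnd.stageId
      }
    })
    sc.parallelize(1 to 100, 4).count()
    // Events are posted asynchronously; drain the bus before asserting on what was collected.
    sc.listenerBus.waitUntilEmpty(10000)
    assert(stageIds.nonEmpty)
  }
}
```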
-class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties with Timeouts { +class SparkSubmitSuite + extends SparkFunSuite + with Matchers + with ResetSystemProperties + with Timeouts { + def beforeAll() { System.setProperty("spark.testing", "true") } @@ -58,7 +62,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties SparkSubmit.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = () => exitedCleanly = true + SparkSubmit.exitFn = (_) => exitedCleanly = true val thread = new Thread { override def run() = try { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 088ca3cb93b49..07d261cc428c4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -20,15 +20,19 @@ package org.apache.spark.deploy import java.io.{File, PrintStream, OutputStream} import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.IBiblioResolver +import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.Utils -class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { +class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var tempIvyPath: String = _ private val noOpOutputStream = new OutputStream { def write(b: Int) = {} @@ -46,6 +50,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { super.beforeAll() // We don't want to write logs during testing SparkSubmitUtils.printStream = new BufferPrintStream + tempIvyPath = Utils.createTempDir(namePrefix = "ivy").getAbsolutePath() } test("incorrect maven coordinate throws error") { @@ -89,21 +94,20 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") - var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(tempIvyPath)) for (i <- 0 until 3) { - val index = jPaths.indexOf(ivyPath) + val index = jPaths.indexOf(tempIvyPath) assert(index >= 0) - jPaths = jPaths.substring(index + ivyPath.length) + jPaths = jPaths.substring(index + tempIvyPath.length) } val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") IvyTestUtils.withRepository(main, None, None) { repo => // end to end val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + Option(tempIvyPath), true) + assert(jarPath.indexOf(tempIvyPath) >= 0, "should use non-default ivy path") } } @@ -122,13 +126,12 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mylib") >= 0, "should find artifact") } // Local ivy repository with modified home - val dummyIvyPath = "dummy" + File.separator + "ivy" - val dummyIvyLocal = new File(dummyIvyPath, "local" + 
File.separator) + val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator) IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal), true) { repo => val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(dummyIvyPath), true) + Some(tempIvyPath), true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index a0a0afa48833e..09075eeb539aa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -17,23 +17,27 @@ package org.apache.spark.deploy.history -import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStreamWriter} +import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, + FileOutputStream, OutputStreamWriter} import java.net.URI import java.util.concurrent.TimeUnit +import java.util.zip.{ZipInputStream, ZipOutputStream} import scala.io.Source +import com.google.common.base.Charsets +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} -class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { private var testDir: File = null @@ -335,6 +339,40 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers assert(!log2.exists()) } + test("Event log copy") { + val provider = new FsHistoryProvider(createTestConf()) + val logs = (1 to 2).map { i => + val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) + writeFile(log, true, None, + SparkListenerApplicationStart( + "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), + SparkListenerApplicationEnd(5001 * i) + ) + log + } + provider.checkForLogs() + + (1 to 2).foreach { i => + val underlyingStream = new ByteArrayOutputStream() + val outputStream = new ZipOutputStream(underlyingStream) + provider.writeEventLogs("downloadApp1", Some(s"attempt$i"), outputStream) + outputStream.close() + val inputStream = new ZipInputStream(new ByteArrayInputStream(underlyingStream.toByteArray)) + var totalEntries = 0 + var entry = inputStream.getNextEntry + entry should not be null + while (entry != null) { + val actual = new String(ByteStreams.toByteArray(inputStream), Charsets.UTF_8) + val expected = Files.toString(logs.find(_.getName == entry.getName).get, Charsets.UTF_8) + actual should be (expected) + totalEntries += 1 + entry = inputStream.getNextEntry + } + totalEntries should be (1) + inputStream.close() + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. 
Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 4adb5122bcf1a..e5b5e1bb65337 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -16,16 +16,19 @@ */ package org.apache.spark.deploy.history -import java.io.{File, FileInputStream, FileWriter, IOException} +import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} +import java.util.zip.ZipInputStream import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import com.google.common.base.Charsets +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} import org.mockito.Mockito.when -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf} +import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.ui.SparkUI /** @@ -39,7 +42,7 @@ import org.apache.spark.ui.SparkUI * expectations. However, in general this should be done with extreme caution, as the metrics * are considered part of Spark's public api. */ -class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with MockitoSugar +class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with MockitoSugar with JsonTestUtils { private val logDir = new File("src/test/resources/spark-events") @@ -82,7 +85,7 @@ class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with "running app list json" -> "applications?status=running", "minDate app list json" -> "applications?minDate=2015-02-10", "maxDate app list json" -> "applications?maxDate=2015-02-10", - "maxDate2 app list json" -> "applications?maxDate=2015-02-03T10:42:40.000CST", + "maxDate2 app list json" -> "applications?maxDate=2015-02-03T16:42:40.000GMT", "one app json" -> "applications/local-1422981780767", "one app multi-attempt json" -> "applications/local-1426533911241", "job list json" -> "applications/local-1422981780767/jobs", @@ -147,6 +150,70 @@ class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with } } + test("download all logs for app with multiple attempts") { + doDownloadTest("local-1430917381535", None) + } + + test("download one log for app with multiple attempts") { + (1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) } + } + + test("download legacy logs - all attempts") { + doDownloadTest("local-1426533911241", None, legacy = true) + } + + test("download legacy logs - single attempts") { + (1 to 2). foreach { + attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true) + } + } + + // Test that the files are downloaded correctly, and validate them. 
+ def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = { + + val url = attemptId match { + case Some(id) => + new URL(s"${generateURL(s"applications/$appId")}/$id/logs") + case None => + new URL(s"${generateURL(s"applications/$appId")}/logs") + } + + val (code, inputStream, error) = HistoryServerSuite.connectAndGetInputStream(url) + code should be (HttpServletResponse.SC_OK) + inputStream should not be None + error should be (None) + + val zipStream = new ZipInputStream(inputStream.get) + var entry = zipStream.getNextEntry + entry should not be null + val totalFiles = { + if (legacy) { + attemptId.map { x => 3 }.getOrElse(6) + } else { + attemptId.map { x => 1 }.getOrElse(2) + } + } + var filesCompared = 0 + while (entry != null) { + if (!entry.isDirectory) { + val expectedFile = { + if (legacy) { + val splits = entry.getName.split("/") + new File(new File(logDir, splits(0)), splits(1)) + } else { + new File(logDir, entry.getName) + } + } + val expected = Files.toString(expectedFile, Charsets.UTF_8) + val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) + actual should be (expected) + filesCompared += 1 + } + entry = zipStream.getNextEntry + } + filesCompared should be (totalFiles) + } + test("response codes on bad paths") { val badAppId = getContentAndCode("applications/foobar") badAppId._1 should be (HttpServletResponse.SC_NOT_FOUND) @@ -202,7 +269,11 @@ class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with } def getUrl(path: String): String = { - HistoryServerSuite.getUrl(new URL(s"http://localhost:$port/api/v1/$path")) + HistoryServerSuite.getUrl(generateURL(path)) + } + + def generateURL(path: String): URL = { + new URL(s"http://localhost:$port/api/v1/$path") } def generateExpectation(name: String, path: String): Unit = { @@ -233,13 +304,18 @@ object HistoryServerSuite { } def getContentAndCode(url: URL): (Int, Option[String], Option[String]) = { + val (code, in, errString) = connectAndGetInputStream(url) + val inString = in.map(IOUtils.toString) + (code, inString, errString) + } + + def connectAndGetInputStream(url: URL): (Int, Option[InputStream], Option[String]) = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod("GET") connection.connect() val code = connection.getResponseCode() - val inString = try { - val in = Option(connection.getInputStream()) - in.map(IOUtils.toString) + val inStream = try { + Option(connection.getInputStream()) } catch { case io: IOException => None } @@ -249,7 +325,7 @@ object HistoryServerSuite { } catch { case io: IOException => None } - (code, inString, errString) + (code, inStream, errString) } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index f97e5ff6db31d..014e87bb40254 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -27,14 +27,14 @@ import scala.language.postfixOps import akka.actor.Address import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy._ -class 
MasterSuite extends FunSuite with Matchers with Eventually { +class MasterSuite extends SparkFunSuite with Matchers with Eventually { test("toAkkaUrl") { val conf = new SparkConf(loadDefaults = false) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index f4d548d9e7720..197f68e7ec5ed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.google.common.base.Charsets -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ @@ -38,7 +38,7 @@ import org.apache.spark.deploy.master.DriverState._ /** * Tests for the REST application submission protocol used in standalone cluster mode. */ -class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { +class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { private var actorSystem: Option[ActorSystem] = None private var server: Option[RestSubmissionServer] = None diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 61071ee17256c..115ac0534a1b4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -21,14 +21,13 @@ import java.lang.Boolean import java.lang.Integer import org.json4s.jackson.JsonMethods._ -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} /** * Tests for the REST application submission protocol. 
*/ -class SubmitRestProtocolSuite extends FunSuite { +class SubmitRestProtocolSuite extends SparkFunSuite { test("validate") { val request = new DummyRequest diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 1c27d83cf876c..5b3930c0b0132 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.deploy.worker +import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.Command import org.apache.spark.util.Utils -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers -class CommandUtilsSuite extends FunSuite with Matchers { +class CommandUtilsSuite extends SparkFunSuite with Matchers { test("set libraryPath correctly") { val appId = "12345-worker321-9876" diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 2159fd8c16c6f..6258c18d177fd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -23,13 +23,12 @@ import org.mockito.Mockito._ import org.mockito.Matchers._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.FunSuite -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.util.Clock -class DriverRunnerTest extends FunSuite { +class DriverRunnerTest extends SparkFunSuite { private def createDriverRunner() = { val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index a8b9df227c996..3da992788962b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -21,12 +21,10 @@ import java.io.File import scala.collection.JavaConversions._ -import org.scalatest.FunSuite - import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} -class ExecutorRunnerTest extends FunSuite { +class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 7cc2104281464..15f7ca4a6dacc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -18,11 +18,10 @@ package org.apache.spark.deploy.worker -import org.apache.spark.SparkConf -import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} -class WorkerArgumentsTest extends FunSuite { +class 
WorkerArgumentsTest extends SparkFunSuite { test("Memory can't be set to 0 when cmd line args leave off M or G") { val conf = new SparkConf @@ -66,7 +65,7 @@ class WorkerArgumentsTest extends FunSuite { } } val conf = new MySparkConf() - val workerArgs = new WorkerArguments(args, conf) + val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 5120) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 450fba21f4b5c..0f4d3b28d09df 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.deploy.worker -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers -class WorkerSuite extends FunSuite with Matchers { +class WorkerSuite extends SparkFunSuite with Matchers { def cmd(javaOpts: String*): Command = { - Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) + Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 6a6f29dd613cd..ac18f04a11475 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.deploy.worker import akka.actor.AddressFromURIString -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} -import org.scalatest.FunSuite -class WorkerWatcherSuite extends FunSuite { +class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala new file mode 100644 index 0000000000000..72eaffb416981 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.worker.ui + +import java.io.{File, FileWriter} + +import org.mockito.Mockito.{mock, when} +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkFunSuite + +class LogPageSuite extends SparkFunSuite with PrivateMethodTester { + + test("get logs simple") { + val webui = mock(classOf[WorkerWebUI]) + val tmpDir = new File(sys.props("java.io.tmpdir")) + val workDir = new File(tmpDir, "work-dir") + workDir.mkdir() + when(webui.workDir).thenReturn(workDir) + val logPage = new LogPage(webui) + + // Prepare some fake log files to read later + val out = "some stdout here" + val err = "some stderr here" + val tmpOut = new File(workDir, "stdout") + val tmpErr = new File(workDir, "stderr") + val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory + val tmpOutBad = new File(tmpDir, "stdout") + val tmpRand = new File(workDir, "random") + write(tmpOut, out) + write(tmpErr, err) + write(tmpOutBad, out) + write(tmpErrBad, err) + write(tmpRand, "1 6 4 5 2 7 8") + + // Get the logs. All log types other than "stderr" or "stdout" will be rejected + val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog) + val (stdout, _, _, _) = + logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100) + val (stderr, _, _, _) = + logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100) + val (error1, _, _, _) = + logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100) + val (error2, _, _, _) = + logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100) + // These files exist, but live outside the working directory + val (error3, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) + val (error4, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) + assert(stdout === out) + assert(stderr === err) + assert(error1.startsWith("Error: Log type must be one of ")) + assert(error2.startsWith("Error: Log type must be one of ")) + assert(error3.startsWith("Error: invalid log directory")) + assert(error4.startsWith("Error: invalid log directory")) + } + + /** Write the specified string to the file. 
*/ + private def write(f: File, s: String): Unit = { + val writer = new FileWriter(f) + try { + writer.write(s) + } finally { + writer.close() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 326e203afe136..8275fd87764cd 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.executor -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class TaskMetricsSuite extends FunSuite { +class TaskMetricsSuite extends SparkFunSuite { test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { val taskMetrics = new TaskMetrics() taskMetrics.updateShuffleReadMetrics() diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 2e58c159a2ed8..63947df3d43a2 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -24,11 +24,10 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -37,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. 
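
The new LogPageSuite above drives LogPage's private getLog method through ScalaTest's PrivateMethodTester instead of widening its visibility. A minimal, self-contained sketch of that pattern, using a made-up Greeter class that is not part of the patch:

import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite

// Hypothetical class with a private method we want to exercise from a test.
class Greeter {
  private def greet(name: String): String = s"hello, $name"
}

class GreeterSuite extends SparkFunSuite with PrivateMethodTester {

  test("invoke a private method reflectively") {
    // PrivateMethod is parameterized on the return type and named by symbol,
    // just as LogPageSuite does with 'getLog above.
    val greet = PrivateMethod[String]('greet)
    val greeter = new Greeter
    // invokePrivate looks the method up reflectively and calls it with the arguments.
    val result = greeter invokePrivate greet("spark")
    assert(result === "hello, spark")
  }
}
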
*/ -class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index cf6a143537889..cbdb33c89d0fb 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} -class CompressionCodecSuite extends FunSuite { +class CompressionCodecSuite extends SparkFunSuite { val conf = new SparkConf(false) def testCodec(codec: CompressionCodec) { diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index ef3e213f1fcce..9e4d34fb7d382 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -36,14 +36,14 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombi import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, RecordReader => NewRecordReader} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.util.Utils -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext +class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { @transient var tmpDir: File = _ @@ -193,26 +193,6 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext assert(records == numRecords) } - test("shuffle records read metrics") { - val recordsRead = runAndReturnShuffleRecordsRead { - sc.textFile(tmpFilePath, 4) - .map(key => (key, 1)) - .groupByKey() - .collect() - } - assert(recordsRead == numRecords) - } - - test("shuffle records written metrics") { - val recordsWritten = runAndReturnShuffleRecordsWritten { - sc.textFile(tmpFilePath, 4) - .map(key => (key, 1)) - .groupByKey() - .collect() - } - assert(recordsWritten == numRecords) - } - /** * Tests the metrics from end to end. 
* 1) reading a hadoop file @@ -263,7 +243,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext val tmpRdd = sc.textFile(tmpFilePath, numPartitions) - val firstSize= runAndReturnBytesRead { + val firstSize = runAndReturnBytesRead { aRdd.count() } val secondSize = runAndReturnBytesRead { @@ -301,14 +281,6 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten)) } - private def runAndReturnShuffleRecordsRead(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead)) - } - - private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten)) - } - private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Option[Long]): Long = { val taskMetrics = new ArrayBuffer[Long]() @@ -433,10 +405,10 @@ class OldCombineTextRecordReaderWrapper( /** * Hadoop 2 has a version of this, but we can't use it for backwards compatibility */ -class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable, Text] { def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) : NewRecordReader[LongWritable, Text] = { - new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + new NewCombineFileRecordReader[LongWritable, Text](split.asInstanceOf[NewCombineFileSplit], context, classOf[NewCombineTextRecordReaderWrapper]) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 100ac77dec1f7..41f2ff725a17b 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.SparkConf -class MetricsConfigSuite extends FunSuite with BeforeAndAfter { +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite + +class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { var filePath: String = _ before { @@ -27,7 +31,9 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { } test("MetricsConfig with default properties") { - val conf = new MetricsConfig(None) + val sparkConf = new SparkConf(loadDefaults = false) + sparkConf.set("spark.metrics.conf", "dummy-file") + val conf = new MetricsConfig(sparkConf) conf.initialize() assert(conf.properties.size() === 4) @@ -40,8 +46,10 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { assert(property.getProperty("sink.servlet.path") === "/metrics/json") } - test("MetricsConfig with properties set") { - val conf = new MetricsConfig(Option(filePath)) + test("MetricsConfig with properties set from a file") { + val sparkConf = new SparkConf(loadDefaults = false) + sparkConf.set("spark.metrics.conf", filePath) + val conf = new MetricsConfig(sparkConf) conf.initialize() val masterProp = conf.getInstance("master") @@ -65,8 +73,71 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") } + test("MetricsConfig with properties set from a Spark configuration") { + val sparkConf = new SparkConf(loadDefaults = false) + 
setMetricsProperty(sparkConf, "*.sink.console.period", "10") + setMetricsProperty(sparkConf, "*.sink.console.unit", "seconds") + setMetricsProperty(sparkConf, "*.source.jvm.class", "org.apache.spark.metrics.source.JvmSource") + setMetricsProperty(sparkConf, "master.sink.console.period", "20") + setMetricsProperty(sparkConf, "master.sink.console.unit", "minutes") + val conf = new MetricsConfig(sparkConf) + conf.initialize() + + val masterProp = conf.getInstance("master") + assert(masterProp.size() === 5) + assert(masterProp.getProperty("sink.console.period") === "20") + assert(masterProp.getProperty("sink.console.unit") === "minutes") + assert(masterProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") + + val workerProp = conf.getInstance("worker") + assert(workerProp.size() === 5) + assert(workerProp.getProperty("sink.console.period") === "10") + assert(workerProp.getProperty("sink.console.unit") === "seconds") + assert(workerProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") + } + + test("MetricsConfig with properties set from a file and a Spark configuration") { + val sparkConf = new SparkConf(loadDefaults = false) + setMetricsProperty(sparkConf, "*.sink.console.period", "10") + setMetricsProperty(sparkConf, "*.sink.console.unit", "seconds") + setMetricsProperty(sparkConf, "*.source.jvm.class", "org.apache.spark.SomeOtherSource") + setMetricsProperty(sparkConf, "master.sink.console.period", "50") + setMetricsProperty(sparkConf, "master.sink.console.unit", "seconds") + sparkConf.set("spark.metrics.conf", filePath) + val conf = new MetricsConfig(sparkConf) + conf.initialize() + + val masterProp = conf.getInstance("master") + assert(masterProp.size() === 5) + assert(masterProp.getProperty("sink.console.period") === "50") + assert(masterProp.getProperty("sink.console.unit") === "seconds") + assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.SomeOtherSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") + + val workerProp = conf.getInstance("worker") + assert(workerProp.size() === 5) + assert(workerProp.getProperty("sink.console.period") === "10") + assert(workerProp.getProperty("sink.console.unit") === "seconds") + assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.SomeOtherSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") + } + test("MetricsConfig with subProperties") { - val conf = new MetricsConfig(Option(filePath)) + val sparkConf = new SparkConf(loadDefaults = false) + sparkConf.set("spark.metrics.conf", filePath) + val conf = new MetricsConfig(sparkConf) conf.initialize() val propCategories = conf.propertyCategories @@ -88,4 +159,9 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { val servletProps = sinkProps("servlet") assert(servletProps.size() === 2) } + + private def setMetricsProperty(conf: SparkConf, name: 
String, value: String): Unit = { + conf.set(s"spark.metrics.conf.$name", value) + } + } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index bbdc9568a6ddb..9c389c76bf3bd 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.metrics.source.Source @@ -27,7 +27,7 @@ import com.codahale.metrics.MetricRegistry import scala.collection.mutable.ArrayBuffer -class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester{ +class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null var securityMgr: SecurityManager = null diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 46d2e5173acae..3940527fb874e 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -31,12 +31,12 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.storage.{BlockId, ShuffleBlockId} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar -import org.scalatest.{FunSuite, ShouldMatchers} +import org.scalatest.ShouldMatchers -class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { +class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { test("security default off") { val conf = new SparkConf() .set("spark.app.id", "app-id") diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index a41f8b7ce5ce0..6f8e8a7ac6033 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -18,11 +18,15 @@ package org.apache.spark.network.netty import org.apache.spark.network.BlockDataManager -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ -class NettyBlockTransferServiceSuite extends FunSuite with BeforeAndAfterEach with ShouldMatchers { +class NettyBlockTransferServiceSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ShouldMatchers { + private var service0: NettyBlockTransferService = _ private var service1: NettyBlockTransferService = _ diff --git 
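
The new MetricsConfigSuite tests above show that MetricsConfig now also picks up metrics properties from the SparkConf itself, via keys of the form spark.metrics.conf.<instance>.<property>, with instance-specific keys overriding the * wildcard and both overriding values loaded from a metrics.properties file. Below is a sketch of how an application might set the same kind of keys when building its configuration; the exact property set is illustrative, and the sink/source class names are the standard Spark ones.

import org.apache.spark.{SparkConf, SparkContext}

object MetricsConfExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("metrics-conf-example")
      .setMaster("local[2]")
      // Same key shape the suite builds via its setMetricsProperty helper:
      // "spark.metrics.conf." + <instance>.<sink or source property>
      .set("spark.metrics.conf.*.sink.console.class",
        "org.apache.spark.metrics.sink.ConsoleSink")
      .set("spark.metrics.conf.*.sink.console.period", "10")
      .set("spark.metrics.conf.*.sink.console.unit", "seconds")
      .set("spark.metrics.conf.*.source.jvm.class",
        "org.apache.spark.metrics.source.JvmSource")
      // Instance-specific settings win over the wildcard, as the suite asserts
      // for the "master" instance.
      .set("spark.metrics.conf.master.sink.console.period", "20")

    val sc = new SparkContext(conf)
    try {
      sc.parallelize(1 to 100).count() // run something so the metrics have data
    } finally {
      sc.stop()
    }
  }
}
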
a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 02424c59d6831..5e364cc0edeb2 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -24,15 +24,13 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, TimeoutException} import scala.language.postfixOps -import org.scalatest.FunSuite - -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.util.Utils /** * Test the ConnectionManager with various security settings. */ -class ConnectionManagerSuite extends FunSuite { +class ConnectionManagerSuite extends SparkFunSuite { test("security default off") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index f2b0ea1063a72..ec99f2a1bad66 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -23,13 +23,13 @@ import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkContext, SparkException, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} -class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts { +class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @transient private var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 01039b9449daf..4e72b89bfcc40 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite - import org.apache.spark._ -class DoubleRDDSuite extends FunSuite with SharedSparkContext { +class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { test("sum") { assert(sc.parallelize(Seq.empty[Double]).sum() === 0.0) assert(sc.parallelize(Seq(1.0)).sum() === 1.0) diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index be8467354b222..08215a2bafc09 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.rdd import java.sql._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{LocalSparkContext, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} -class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { Class.forName("org.apache.derby.jdbc.EmbeddedDriver") @@ -82,7 +82,7 @@ class JdbcRDDSuite extends FunSuite with 
BeforeAndAfter with LocalSparkContext { assert(rdd.count === 100) assert(rdd.reduce(_ + _) === 10100) } - + test("large id overflow") { sc = new SparkContext("local", "test") val rdd = new JdbcRDD( diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index ca0d953d306d8..dfa102f432a02 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -28,12 +28,10 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} -import org.apache.spark.{Partitioner, SharedSparkContext} +import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite} import org.apache.spark.util.Utils -import org.scalatest.FunSuite - -class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { +class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("aggregateByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2) @@ -512,17 +510,17 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("lookup") { - val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) assert(pairs.partitioner === None) assert(pairs.lookup(1) === Seq(2)) - assert(pairs.lookup(5) === Seq(6,7)) + assert(pairs.lookup(5) === Seq(6, 7)) assert(pairs.lookup(-1) === Seq()) } test("lookup with partitioner") { - val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) val p = new Partitioner { def numPartitions: Int = 2 @@ -533,12 +531,12 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { assert(shuffled.partitioner === Some(p)) assert(shuffled.lookup(1) === Seq(2)) - assert(shuffled.lookup(5) === Seq(6,7)) + assert(shuffled.lookup(5) === Seq(6, 7)) assert(shuffled.lookup(-1) === Seq()) } test("lookup with bad partitioner") { - val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) val p = new Partitioner { def numPartitions: Int = 2 diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 1880364581c1a..e7cc1617cdf1c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -22,10 +22,11 @@ import scala.collection.immutable.NumericRange import org.scalacheck.Arbitrary._ import org.scalacheck.Gen import org.scalacheck.Prop._ -import org.scalatest.FunSuite import org.scalatest.prop.Checkers -class ParallelCollectionSplitSuite extends FunSuite with Checkers { +import org.apache.spark.SparkFunSuite + +class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { test("one element per slice") { val data = Array(1, 2, 3) val slices = ParallelCollectionRDD.slice(data, 3) diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 465068c6cbb16..b1544a6106110 100644 --- 
a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite +import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext} -import org.apache.spark.{Partition, SharedSparkContext, TaskContext} - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { +class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext { test("Pruned Partitions inherit locality prefs correctly") { diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 0d1369c19c69e..132a5fa9a80fb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite - -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler} /** a sampler that outputs its seed */ @@ -38,7 +36,7 @@ class MockSampler extends RandomSampler[Long, Long] { override def clone: MockSampler = new MockSampler } -class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { +class PartitionwiseSampledRDDSuite extends SparkFunSuite with SharedSparkContext { test("seed distribution") { val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 85eb2a1d07ba4..32f04d54eff94 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} -import org.scalatest.FunSuite import scala.collection.Map import scala.language.postfixOps @@ -32,7 +31,7 @@ import scala.util.Try import org.apache.spark._ import org.apache.spark.util.Utils -class PipedRDDSuite extends FunSuite with SharedSparkContext { +class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { test("basic pipe") { if (testCommandAvailable("cat")) { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index 4434ed858c60c..f65349e3e3585 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.rdd -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{TaskContext, Partition, SparkContext} +import org.apache.spark.{Partition, SparkContext, SparkFunSuite, TaskContext} /** * Tests whether scopes are passed from the RDD operation to the RDDs correctly. 
*/ -class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter { +class RDDOperationScopeSuite extends SparkFunSuite with BeforeAndAfter { private var sc: SparkContext = null private val scope1 = new RDDOperationScope("scope1") private val scope2 = new RDDOperationScope("scope2", Some(scope1)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index afc11bdc4d6ab..f6da9f98ad253 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -25,14 +25,12 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ import org.apache.spark.util.Utils -class RDDSuite extends FunSuite with SharedSparkContext { +class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -338,10 +336,10 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("coalesced RDDs with locality") { - val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) + val data3 = sc.makeRDD(List((1, List("a", "c")), (2, List("a", "b", "c")), (3, List("b")))) val coal3 = data3.coalesce(3) val list3 = coal3.partitions.flatMap(_.asInstanceOf[CoalescedRDDPartition].preferredLocation) - assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") + assert(list3.sorted === Array("a", "b", "c"), "Locality preferences are dropped") // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map{ j => "m" + (j%6)}))) @@ -591,8 +589,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sc.emptyRDD.isEmpty()) assert(sc.parallelize(Seq[Int]()).isEmpty()) assert(!sc.parallelize(Seq(1)).isEmpty()) - assert(sc.parallelize(Seq(1,2,3), 3).filter(_ < 0).isEmpty()) - assert(!sc.parallelize(Seq(1,2,3), 3).filter(_ > 1).isEmpty()) + assert(sc.parallelize(Seq(1, 2, 3), 3).filter(_ < 0).isEmpty()) + assert(!sc.parallelize(Seq(1, 2, 3), 3).filter(_ > 1).isEmpty()) } test("sample preserves partitioner") { @@ -609,49 +607,49 @@ class RDDSuite extends FunSuite with SharedSparkContext { val data = sc.parallelize(1 to n, 2) for (num <- List(5, 20, 100)) { - val sample = data.takeSample(withReplacement=false, num=num) + val sample = data.takeSample(withReplacement = false, num = num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 20, seed) + val sample = data.takeSample(withReplacement = false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 100, seed) + val sample = data.takeSample(withReplacement = false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for 
(seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 20, seed) + val sample = data.takeSample(withReplacement = true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, num=20) + val sample = data.takeSample(withReplacement = true, num = 20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { - val sample = data.takeSample(withReplacement=true, num=n) + val sample = data.takeSample(withReplacement = true, num = n) assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, n, seed) + val sample = data.takeSample(withReplacement = true, n, seed) assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 2 * n, seed) + val sample = data.takeSample(withReplacement = true, 2 * n, seed) assert(sample.size === 2 * n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") @@ -691,7 +689,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("sortByKey") { - val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + val data = sc.parallelize(Seq("5|50|A", "4|60|C", "6|40|B")) val col1 = Array("4|60|C", "5|50|A", "6|40|B") val col2 = Array("6|40|B", "5|50|A", "4|60|C") @@ -703,7 +701,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("sortByKey ascending parameter") { - val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + val data = sc.parallelize(Seq("5|50|A", "4|60|C", "6|40|B")) val asc = Array("4|60|C", "5|50|A", "6|40|B") val desc = Array("6|40|B", "5|50|A", "4|60|C") @@ -764,9 +762,9 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("intersection strips duplicates in an input") { - val a = sc.parallelize(Seq(1,2,3,3)) - val b = sc.parallelize(Seq(1,1,2,3)) - val intersection = Array(1,2,3) + val a = sc.parallelize(Seq(1, 2, 3, 3)) + val b = sc.parallelize(Seq(1, 1, 2, 3)) + val intersection = Array(1, 2, 3) assert(a.intersection(b).collect().sorted === intersection) assert(b.intersection(a).collect().sorted === intersection) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala index fe695d85e29dd..194dc45d6e399 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -21,11 +21,11 @@ object RDDSuiteUtils { case class Person(first: String, last: String, age: Int) object AgeOrdering extends Ordering[Person] { - def compare(a:Person, b:Person): Int = a.age.compare(b.age) + def compare(a: Person, b: Person): Int = 
a.age.compare(b.age) } object NameOrdering extends Ordering[Person] { - def compare(a:Person, b:Person): Int = - implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) + def compare(a: Person, b: Person): Int = + implicitly[Ordering[Tuple2[String, String]]].compare((a.last, a.first), (b.last, b.first)) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index 64b1c24c47168..a7de9cabe7cc9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.{Logging, SharedSparkContext} +import org.apache.spark.{Logging, SharedSparkContext, SparkFunSuite} -class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging { +class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers with Logging { test("sortByKey") { val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + assert(pairs.sortByKey().collect() === Array((0, 0), (1, 0), (2, 0), (3, 0))) } test("large array") { @@ -136,7 +135,7 @@ class SortingSuite extends FunSuite with SharedSparkContext with Matchers with L test("get a range of elements in an array not partitioned by a range partitioner") { val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) - val pairs = sc.parallelize(pairArr,10) + val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) } diff --git a/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala index 72596e86865b2..5d7b973fbd9ac 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.SharedSparkContext -import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} object ZippedPartitionsSuite { def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { @@ -26,7 +25,7 @@ object ZippedPartitionsSuite { } } -class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { +class ZippedPartitionsSuite extends SparkFunSuite with SharedSparkContext { test("print sizes") { val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index ae3339d80f9c6..1f0aa759b08da 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -24,15 +24,15 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} /** * Common tests for an RpcEnv implementation. 
*/ -abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { +abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { var env: RpcEnv = _ @@ -42,7 +42,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } override def afterAll(): Unit = { - if(env != null) { + if (env != null) { env.shutdown() } } @@ -75,7 +75,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -338,7 +338,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("call receive in sequence") { // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it - for(i <- 0 until 100) { + for (i <- 0 until 100) { @volatile var result = 0 val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { override val rpcEnv = env diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index f77661ccbd1c5..34145691153ce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.scheduler -import org.apache.spark.{LocalSparkContext, SparkConf, SparkException, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.util.{SerializableBuffer, AkkaUtils} -import org.scalatest.FunSuite - -class CoarseGrainedSchedulerBackendSuite extends FunSuite with LocalSparkContext { +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { test("serialized task larger than akka frame size") { val conf = new SparkConf - conf.set("spark.akka.frameSize","1") - conf.set("spark.default.parallelism","1") + conf.set("spark.akka.frameSize", "1") + conf.set("spark.default.parallelism", "1") sc = new SparkContext("local-cluster[2 , 1 , 512]", "test", conf) val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6a8ae29aae675..47b2868753c0e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.{BeforeAndAfter, FunSuiteLike} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -68,7 +68,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception class DAGSchedulerSuite - extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { + extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. 
*/ @@ -254,7 +254,7 @@ class DAGSchedulerSuite test("[SPARK-3353] parent stage should have lower stage id") { sparkListener.stageByOrderOfExecution.clear() sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.stageByOrderOfExecution.length === 2) assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1)) } @@ -318,7 +318,7 @@ class DAGSchedulerSuite } test("cache location preferences w/ dependency") { - val baseRdd = new MyRDD(sc, 1, Nil) + val baseRdd = new MyRDD(sc, 1, Nil).cache() val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) cacheLocations(baseRdd.id -> 0) = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) @@ -331,7 +331,7 @@ class DAGSchedulerSuite } test("regression test for getCacheLocs") { - val rdd = new MyRDD(sc, 3, Nil) + val rdd = new MyRDD(sc, 3, Nil).cache() cacheLocations(rdd.id -> 0) = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) cacheLocations(rdd.id -> 1) = @@ -342,13 +342,40 @@ class DAGSchedulerSuite assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD"))) } + /** + * This test ensures that if a particular RDD is cached, RDDs earlier in the dependency chain + * are not computed. It constructs the following chain of dependencies: + * +---+ shuffle +---+ +---+ +---+ + * | A |<--------| B |<---| C |<---| D | + * +---+ +---+ +---+ +---+ + * Here, B is derived from A by performing a shuffle, C has a one-to-one dependency on B, + * and D similarly has a one-to-one dependency on C. If none of the RDDs were cached, this + * set of RDDs would result in a two stage job: one ShuffleMapStage, and a ResultStage that + * reads the shuffled data from RDD A. This test ensures that if C is cached, the scheduler + * doesn't perform a shuffle, and instead computes the result using a single ResultStage + * that reads C's cached data. + */ + test("getMissingParentStages should consider all ancestor RDDs' cache statuses") { + val rddA = new MyRDD(sc, 1, Nil) + val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, null))) + val rddC = new MyRDD(sc, 1, List(new OneToOneDependency(rddB))).cache() + val rddD = new MyRDD(sc, 1, List(new OneToOneDependency(rddC))) + cacheLocations(rddC.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + submit(rddD, Array(0)) + assert(scheduler.runningStages.size === 1) + // Make sure that the scheduler is running the final result stage. + // Because C is cached, the shuffle map stage to compute A does not need to be run. + assert(scheduler.runningStages.head.isInstanceOf[ResultStage]) + } + test("avoid exponential blowup when getting preferred locs list") { // Build up a complex dependency graph with repeated zip operations, without preferred locations var rdd: RDD[_] = new MyRDD(sc, 1, Nil) (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. failAfter(10 seconds) { - val preferredLocs = scheduler.getPreferredLocs(rdd,0) + val preferredLocs = scheduler.getPreferredLocs(rdd, 0) // No preferred locations are returned. 
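
The new DAGSchedulerSuite test above ("getMissingParentStages should consider all ancestor RDDs' cache statuses") exercises the scheduler directly with the suite's stub RDDs. The same dependency shape can be written with ordinary RDD operations; the sketch below is only an illustration of the behaviour that test protects, assuming a local SparkContext: groupByKey introduces the shuffle dependency (A -> B), mapValues and map add one-to-one dependencies (B -> C and C -> D), and because C is cached, a second action on D can be served from C's cached partitions instead of re-running the shuffle.

import org.apache.spark.{SparkConf, SparkContext}

object CachedLineageExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("cached-lineage")
    val sc = new SparkContext(conf)
    try {
      val a = sc.parallelize(1 to 1000).map(i => (i % 10, i))  // A
      val b = a.groupByKey()                                   // B: shuffle dependency on A
      val c = b.mapValues(_.sum).cache()                       // C: one-to-one on B, cached
      val d = c.map { case (_, sum) => sum.toLong }            // D: one-to-one on C

      d.count() // first action: runs the shuffle and materializes C's cache
      d.count() // second action: C's partitions come from the cache, so the
                // stage that computes A and the shuffle are not re-run
    } finally {
      sc.stop()
    }
  }
}
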
assert(preferredLocs.length === 0) } @@ -362,7 +389,7 @@ class DAGSchedulerSuite submit(unserializableRdd, Array(0)) assert(failure.getMessage.startsWith( "Job aborted due to stage failure: Task not serializable:")) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -372,7 +399,7 @@ class DAGSchedulerSuite submit(new MyRDD(sc, 1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted due to stage failure: some failure") - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -383,7 +410,7 @@ class DAGSchedulerSuite val jobId = submit(rdd, Array(0)) cancel(jobId) assert(failure.getMessage === s"Job $jobId cancelled ") - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -435,7 +462,7 @@ class DAGSchedulerSuite assert(results === Map(0 -> 42)) assertDataStructuresEmpty() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.isEmpty) assert(sparkListener.successfulStages.contains(0)) } @@ -504,7 +531,7 @@ class DAGSchedulerSuite Map[Long, Any](), createFakeTaskInfo(), null)) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. @@ -516,7 +543,7 @@ class DAGSchedulerSuite createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.size == 1) } @@ -565,7 +592,7 @@ class DAGSchedulerSuite // Listener bus should get told about the map stage failing, but not the reduce stage // (since the reduce stage hasn't been started yet). - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.toSet === Set(0)) assertDataStructuresEmpty() @@ -607,8 +634,8 @@ class DAGSchedulerSuite val listener1 = new FailureRecordingJobListener() val listener2 = new FailureRecordingJobListener() - submit(reduceRdd1, Array(0, 1), listener=listener1) - submit(reduceRdd2, Array(0, 1), listener=listener2) + submit(reduceRdd1, Array(0, 1), listener = listener1) + submit(reduceRdd2, Array(0, 1), listener = listener2) val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) @@ -616,7 +643,7 @@ class DAGSchedulerSuite assert(cancelledStages.toSet === Set(0, 2)) // Make sure the listeners got told about both failed stages. 
- assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.successfulStages.isEmpty) assert(sparkListener.failedStages.toSet === Set(0, 2)) @@ -678,9 +705,9 @@ class DAGSchedulerSuite } test("cached post-shuffle") { - val shuffleOneRdd = new MyRDD(sc, 2, Nil) + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index b52a8d11d147d..f681f21b6205e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -25,7 +25,7 @@ import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ -import org.scalatest.{FunSuiteLike, BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -39,7 +39,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * logging events, whether the parsing of the file names is correct, and whether the logged events * can be read and deserialized into actual SparkListenerEvents. */ -class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter +class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with Logging { import EventLoggingListenerSuite._ diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 950c6dc58e332..b8e466fab4506 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.scheduler import org.apache.spark.storage.BlockManagerId -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import scala.util.Random -class MapStatusSuite extends FunSuite { +class MapStatusSuite extends SparkFunSuite { test("compressSize") { assert(MapStatus.compressSize(0L) === 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 7078a7a12232a..a9036da9cc93d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -24,7 +24,7 @@ import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} @@ -64,7 +64,7 @@ import scala.language.postfixOps * increments would be captured even though the commit in both tasks was executed * erroneously. 
*/ -class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { +class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { var outputCommitCoordinator: OutputCommitCoordinator = null var tempDir: File = null diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index e8f461e2f56c9..467796d7c24b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.scheduler import java.util.Properties -import org.scalatest.FunSuite - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} /** * Tests that pools and the associated scheduling algorithms for FIFO and fair scheduling work * correctly. */ -class PoolSuite extends FunSuite with LocalSparkContext { +class PoolSuite extends SparkFunSuite with LocalSparkContext { def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl) : TaskSetManager = { @@ -97,9 +95,9 @@ class PoolSuite extends FunSuite with LocalSparkContext { assert(rootPool.getSchedulableByName("3").weight === 1) val properties1 = new Properties() - properties1.setProperty("spark.scheduler.pool","1") + properties1.setProperty("spark.scheduler.pool", "1") val properties2 = new Properties() - properties2.setProperty("spark.scheduler.pool","2") + properties2.setProperty("spark.scheduler.pool", "2") val taskSetManager10 = createTaskSetManager(0, 1, taskScheduler) val taskSetManager11 = createTaskSetManager(1, 1, taskScheduler) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index dabe4574b6456..ff3fa95ec32ae 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -21,10 +21,10 @@ import java.io.{File, PrintWriter} import java.net.URI import org.json4s.jackson.JsonMethods._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkContext, SPARK_VERSION} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -32,7 +32,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} /** * Test whether ReplayListenerBus replays events from logs correctly. 
*/ -class ReplayListenerSuite extends FunSuite with BeforeAndAfter { +class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { private val fileSystem = Utils.getHadoopFileSystem("/", SparkHadoopUtil.get.newConfiguration(new SparkConf())) private var testDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 825c616c0c3e0..651295b7344c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -22,13 +22,13 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.JavaConversions._ -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers +class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { /** Length of time to wait while draining listener events. */ @@ -47,7 +47,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // Starting listener bus should flush all buffered events bus.start(sc) - assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) // After listener bus has stopped, posting events should not increment counter @@ -131,7 +131,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers rdd2.setName("Target RDD") rdd2.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val (stageInfo, taskInfoMetrics) = listener.stageInfos.head @@ -156,7 +156,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers rdd3.setName("Trois") rdd1.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD @@ -165,7 +165,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers listener.stageInfos.clear() rdd2.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get stageInfo2.rddInfos.size should be {3} // ParallelCollectionRDD, FilteredRDD, MappedRDD @@ -174,7 +174,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers listener.stageInfos.clear() rdd3.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {2} // Shuffle map stage + result stage val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get stageInfo3.rddInfos.size should be {1} // ShuffledRDD @@ -190,7 +190,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val rdd2 = rdd1.map(_.toString) sc.runJob(rdd2, (items: 
Iterator[String]) => items.size, Seq(0, 1), true) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val (stageInfo, _) = listener.stageInfos.head @@ -214,7 +214,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val d = sc.parallelize(0 to 1e4.toInt, 64).map(w) d.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be (1) val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") @@ -225,7 +225,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers d4.setName("A Cogroup") d4.collectAsMap() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be (4) listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) => /** @@ -281,7 +281,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers .reduce { case (x, y) => x } assert(result === 1.to(akkaFrameSize).toArray) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val TASK_INDEX = 0 assert(listener.startedTasks.contains(TASK_INDEX)) assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) @@ -297,7 +297,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } assert(result === 2) - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val TASK_INDEX = 0 assert(listener.startedTasks.contains(TASK_INDEX)) assert(listener.startedGettingResultTasks.isEmpty) @@ -352,7 +352,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } - assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) // The exception should be caught, and the event should be propagated to other listeners assert(bus.listenerThreadIsAlive) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index 623a687c359a2..d97fba00976d2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.scheduler -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.{SparkContext, LocalSparkContext} +import scala.collection.mutable -import org.scalatest.{FunSuite, BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import scala.collection.mutable +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.cluster.ExecutorInfo /** * Unit tests for SparkListener that require a local cluster. */ -class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext +class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. 
*/ @@ -41,12 +41,16 @@ class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext val listener = new SaveExecutorInfo sc.addSparkListener(listener) + // This test will check if the number of executors received by "SparkListener" is same as the + // number of all executors, so we need to wait until all executors are up + sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) rdd2.setName("Target RDD") rdd2.count() - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(listener.addedExecutorInfo.size == 2) assert(listener.addedExecutorInfo("0").totalCores == 1) assert(listener.addedExecutorInfo("1").totalCores == 1) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 83ae8701243e5..7c1adc1aef1b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import org.mockito.Mockito._ import org.mockito.Matchers.any -import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} -class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index e3a3803e6483a..815caa79ff529 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -23,10 +23,10 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.storage.TaskResultBlockId /** @@ -71,7 +71,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule /** * Tests related to handling task results (both direct and indirect). */ -class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small // as we can make it) so the tests don't take too long. 
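
The change repeated throughout the test files above and below is the switch from ScalaTest's FunSuite to a shared SparkFunSuite base class, with the matching import moving from org.scalatest.FunSuite to org.apache.spark.SparkFunSuite. The definition of SparkFunSuite is not part of this section of the patch, so the sketch below is only an illustration of what such a base class could look like, assuming it simply wraps FunSuite and logs the suite and test name around every test so that output in long CI logs is easier to attribute; the package placement, visibility, and logging details are assumptions, not the actual implementation.

package org.apache.spark

import org.scalatest.{FunSuite, Outcome}

// Illustrative sketch only: a shared test base class that extends ScalaTest's
// FunSuite and logs each test's name as it starts and finishes. The real
// SparkFunSuite referenced by this patch may differ in detail.
private[spark] abstract class SparkFunSuite extends FunSuite with Logging {

  protected override def withFixture(test: NoArgTest): Outcome = {
    val suiteName = this.getClass.getName
    val testName = test.text
    try {
      logInfo(s"===== TEST OUTPUT FOR $suiteName: '$testName' =====")
      test()   // run the actual test body
    } finally {
      logInfo(s"===== FINISHED $suiteName: '$testName' =====")
    }
  }
}

Usage in the suites touched by this patch is unchanged from a plain FunSuite: a suite is declared as class MySuite extends SparkFunSuite { test("...") { ... } }, and only the base class (and its import) differs, which is why the diffs above consist almost entirely of one-line import and extends changes.
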
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index ffa4381969b68..a6d5232feb8de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import org.scalatest.FunSuite - import org.apache.spark._ class FakeSchedulerBackend extends SchedulerBackend { @@ -28,7 +26,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def defaultParallelism(): Int = 1 } -class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with Logging { test("Scheduler does not always schedule tasks on the same workers") { sc = new SparkContext("local", "TaskSchedulerImplSuite") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6198cea46ddf8..0060f3396dcde 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,8 +22,6 @@ import java.util.Random import scala.collection.mutable.ArrayBuffer import scala.collection.mutable -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{ManualClock, Utils} @@ -146,7 +144,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() } -class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging { import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL} private val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala index 3fa0115e68259..e72285d03d3ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala @@ -18,22 +18,21 @@ package org.apache.spark.scheduler.cluster.mesos import org.mockito.Mockito._ -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class MemoryUtilsSuite extends FunSuite with MockitoSugar { +class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { test("MesosMemoryUtils should always override memoryOverhead when it's set") { val sparkConf = new SparkConf val sc = mock[SparkContext] when(sc.conf).thenReturn(sparkConf) - + // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 when(sc.executorMemory).thenReturn(512) assert(MemoryUtils.calculateTotalMemory(sc) === 896) - + // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 when(sc.executorMemory).thenReturn(4096) assert(MemoryUtils.calculateTotalMemory(sc) === 4505) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 
ab863f3d8d672..68df46a41ddc8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -30,16 +30,15 @@ import org.apache.mesos.SchedulerDriver import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.{ArgumentCaptor, Matchers} -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { +class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { test("check spark-class location correctly") { val conf = new SparkConf @@ -80,11 +79,11 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo .set("spark.mesos.executor.docker.image", "spark/mock") .set("spark.mesos.executor.docker.volumes", "/a,/b:/b,/c:/c:rw,/d:ro,/e:/e:ro") .set("spark.mesos.executor.docker.portmaps", "80:8080,53:53:tcp") - + val listenerBus = mock[LiveListenerBus] listenerBus.post( SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - + val sc = mock[SparkContext] when(sc.executorMemory).thenReturn(100) when(sc.getSparkHome()).thenReturn(Option("/spark-home")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala index eebcba40f8a1c..5a81bb335fdb7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class MesosTaskLaunchDataSuite extends FunSuite { +class MesosTaskLaunchDataSuite extends SparkFunSuite { test("serialize and deserialize data must be same") { val serializedTask = ByteBuffer.allocate(40) (Range(100, 110).map(serializedTask.putInt(_))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala index f28e29e9b8d8e..f5cef1caaf1ac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -19,16 +19,15 @@ package org.apache.spark.scheduler.mesos import java.util.Date -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.{LocalSparkContext, SparkConf} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} -class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { +class MesosClusterSchedulerSuite extends SparkFunSuite 
with LocalSparkContext with MockitoSugar { private val command = new Command("mainClass", Seq("arg"), null, null, null, null) diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index ed4d8ce632e16..329a2b6dad831 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.serializer -import org.apache.spark.SparkConf -import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} -class JavaSerializerSuite extends FunSuite { +class JavaSerializerSuite extends SparkFunSuite { test("JavaSerializer instances are serializable") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 054a4c64897a9..63a8480c9b57b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.serializer import org.apache.spark.util.Utils import com.esotericsoftware.kryo.Kryo -import org.scalatest.FunSuite -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ -class KryoSerializerDistributedSuite extends FunSuite { +class KryoSerializerDistributedSuite extends SparkFunSuite { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index da98d09184735..a9b209ccfc76e 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.serializer -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext import org.apache.spark.SparkException -class KryoSerializerResizableOutputSuite extends FunSuite { +class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer val x = (1 to 400000).toArray diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 5faf108b394a1..23a1fdb0f5009 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,57 +17,59 @@ package org.apache.spark.serializer -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo -import org.scalatest.FunSuite -import org.apache.spark.{SharedSparkContext, SparkConf} +import 
org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId -class KryoSerializerSuite extends FunSuite with SharedSparkContext { +class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) - test("configuration limits") { - val conf1 = conf.clone() + test("SPARK-7392 configuration limits") { val kryoBufferProperty = "spark.kryoserializer.buffer" val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" - conf1.set(kryoBufferProperty, "64k") - conf1.set(kryoBufferMaxProperty, "64m") - new KryoSerializer(conf1).newInstance() + + def newKryoInstance( + conf: SparkConf, + bufferSize: String = "64k", + maxBufferSize: String = "64m"): SerializerInstance = { + val kryoConf = conf.clone() + kryoConf.set(kryoBufferProperty, bufferSize) + kryoConf.set(kryoBufferMaxProperty, maxBufferSize) + new KryoSerializer(kryoConf).newInstance() + } + + // test default values + newKryoInstance(conf, "64k", "64m") // 2048m = 2097152k - conf1.set(kryoBufferProperty, "2097151k") - conf1.set(kryoBufferMaxProperty, "64m") // should not throw exception when kryoBufferMaxProperty < kryoBufferProperty - new KryoSerializer(conf1).newInstance() - conf1.set(kryoBufferMaxProperty, "2097151k") - new KryoSerializer(conf1).newInstance() - val conf2 = conf.clone() - conf2.set(kryoBufferProperty, "2048m") - val thrown1 = intercept[IllegalArgumentException](new KryoSerializer(conf2).newInstance()) + newKryoInstance(conf, "2097151k", "64m") + // test maximum size with unit of KiB + newKryoInstance(conf, "2097151k", "2097151k") + // should throw exception with bufferSize out of bound + val thrown1 = intercept[IllegalArgumentException](newKryoInstance(conf, "2048m")) assert(thrown1.getMessage.contains(kryoBufferProperty)) - val conf3 = conf.clone() - conf3.set(kryoBufferMaxProperty, "2048m") - val thrown2 = intercept[IllegalArgumentException](new KryoSerializer(conf3).newInstance()) + // should throw exception with maxBufferSize out of bound + val thrown2 = intercept[IllegalArgumentException]( + newKryoInstance(conf, maxBufferSize = "2048m")) assert(thrown2.getMessage.contains(kryoBufferMaxProperty)) - val conf4 = conf.clone() - conf4.set(kryoBufferProperty, "2g") - conf4.set(kryoBufferMaxProperty, "3g") - val thrown3 = intercept[IllegalArgumentException](new KryoSerializer(conf4).newInstance()) + // should throw exception when both bufferSize and maxBufferSize out of bound + // exception should only contain "spark.kryoserializer.buffer" + val thrown3 = intercept[IllegalArgumentException](newKryoInstance(conf, "2g", "3g")) assert(thrown3.getMessage.contains(kryoBufferProperty)) assert(!thrown3.getMessage.contains(kryoBufferMaxProperty)) - val conf5 = conf.clone() - conf5.set(kryoBufferProperty, "8m") - conf5.set(kryoBufferMaxProperty, "9m") - new KryoSerializer(conf5).newInstance() + // test configuration with mb is supported properly + newKryoInstance(conf, "8m", "9m") } - + test("basic types") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { @@ -106,7 +108,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check((1, 1)) check((1, 1L)) check((1L, 1)) - check((1L, 1L)) + check((1L, 1L)) check((1.0, 1)) check((1, 1.0)) check((1.0, 1.0)) @@ -144,7 +146,7 @@ class 
KryoSerializerSuite extends FunSuite with SharedSparkContext { check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one",2->"two",3->"three"))) + mutable.HashMap(1->"one", 2->"two", 3->"three"))) } test("ranges") { @@ -358,6 +360,41 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } +class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSparkContext { + conf.set("spark.serializer", classOf[KryoSerializer].getName) + conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) + conf.set("spark.kryo.referenceTracking", "true") + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.sort.bypassMergeThreshold", "200") + + test("sort-shuffle with bypassMergeSort (SPARK-7873)") { + val myObject = ("Hello", "World") + assert(sc.parallelize(Seq.fill(100)(myObject)).repartition(2).collect().toSet === Set(myObject)) + } + + test("calling deserialize() after deserializeStream()") { + val serInstance = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] + assert(!serInstance.getAutoReset()) + val hello = "Hello" + val world = "World" + // Here, we serialize the same value twice, so the reference-tracking should cause us to store + // references to some of these values + val helloHello = serInstance.serialize((hello, hello)) + // Here's a stream which only contains one value + val worldWorld: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val serStream = serInstance.serializeStream(baos) + serStream.writeObject(world) + serStream.writeObject(world) + serStream.close() + baos.toByteArray + } + val deserializationStream = serInstance.deserializeStream(new ByteArrayInputStream(worldWorld)) + assert(deserializationStream.readValue[Any]() === world) + deserializationStream.close() + assert(serInstance.deserialize[Any](helloHello) === (hello, hello)) + } +} class ClassLoaderTestingObject diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index 433fd6bb4a11d..c657414e9e5c3 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -17,19 +17,17 @@ package org.apache.spark.serializer -import org.scalatest.FunSuite - -import org.apache.spark.{SharedSparkContext, SparkException} +import org.apache.spark.{SharedSparkContext, SparkException, SparkFunSuite} import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { def op[T](x: T): String = x.toString - + def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } -class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { +class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkContext { def fixture: (RDD[String], UnserializableClass) = { (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) @@ -47,7 +45,7 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex // iterating over a map from transformation names to functions that perform that // transformation on a given RDD, creating one test case for each - for (transformation <- + for (transformation <- Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" 
-> xfilter _, @@ -60,24 +58,24 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex val ex = intercept[SparkException] { xf(data, uc) } - assert(ex.getMessage.contains("Task not serializable"), + assert(ex.getMessage.contains("Task not serializable"), s"RDD.$name doesn't proactively throw NotSerializableException") } } - private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.map(y=>uc.op(y)) + private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.map(y => uc.op(y)) + + private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.flatMap(y => Seq(uc.op(y))) - private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.flatMap(y=>Seq(uc.op(y))) + private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.filter(y => uc.pred(y)) - private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.filter(y=>uc.pred(y)) + private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.mapPartitions(_.map(y => uc.op(y))) - private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapPartitions(_.map(y=>uc.op(y))) + private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = + x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y))) - private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) - } diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index e62828c4fbac6..2707bb53bc383 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.serializer import java.io.{ObjectOutput, ObjectInput} -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.BeforeAndAfterEach +import org.apache.spark.SparkFunSuite -class SerializationDebuggerSuite extends FunSuite with BeforeAndAfterEach { + +class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { import SerializationDebugger.find diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index bb34033fe9e7e..4ce3b941bea55 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -21,9 +21,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.util.Random -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** @@ -31,7 +31,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset * describe properties of the serialized stream, such as * [[Serializer.supportsRelocationOfSerializedObjects]]. 
*/ -class SerializerPropertiesSuite extends FunSuite { +class SerializerPropertiesSuite extends SparkFunSuite { import SerializerPropertiesSuite._ diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 86fcf447287f7..c1e0a29a34bb1 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -32,16 +32,19 @@ class TestSerializer extends Serializer { class TestSerializerInstance extends SerializerInstance { - override def serialize[T: ClassTag](t: T): ByteBuffer = ??? + override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - override def serializeStream(s: OutputStream): SerializationStream = ??? + override def serializeStream(s: OutputStream): SerializationStream = + throw new UnsupportedOperationException override def deserializeStream(s: InputStream): TestDeserializationStream = new TestDeserializationStream - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ??? + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index e0e646f0a3652..96778c9ebafb1 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.shuffle -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.CountDownLatch -class ShuffleMemoryManagerSuite extends FunSuite with Timeouts { +import org.apache.spark.SparkFunSuite + +class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { /** Launch a thread with the given body block and return it. 
*/ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 0537bf66ad020..491dc3659e184 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -21,16 +21,14 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls -import org.scalatest.FunSuite - -import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockResolver import org.apache.spark.storage.{ShuffleBlockId, FileSegment} -class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { +class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { private val testConf = new SparkConf(false) private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..542f8f45125a4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort + +import java.io.File +import java.util.UUID + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark._ +import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} +import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer} +import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + + private var taskMetrics: TaskMetrics = _ + private var shuffleWriteMetrics: ShuffleWriteMetrics = _ + private var tempDir: File = _ + private var outputFile: File = _ + private val conf: SparkConf = new SparkConf(loadDefaults = false) + private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() + private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] + private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0) + private val serializer: Serializer = new JavaSerializer(conf) + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + outputFile = File.createTempFile("shuffle", null, tempDir) + shuffleWriteMetrics = new ShuffleWriteMetrics + taskMetrics = new TaskMetrics + taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + MockitoAnnotations.initMocks(this) + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.getDiskWriter( + any[BlockId], + any[File], + any[SerializerInstance], + anyInt(), + any[ShuffleWriteMetrics] + )).thenAnswer(new Answer[BlockObjectWriter] { + override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + val args = invocation.getArguments + new DiskBlockObjectWriter( + args(0).asInstanceOf[BlockId], + args(1).asInstanceOf[File], + args(2).asInstanceOf[SerializerInstance], + args(3).asInstanceOf[Int], + compressStream = identity, + syncWrites = false, + args(4).asInstanceOf[ShuffleWriteMetrics] + ) + } + }) + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer[(TempShuffleBlockId, File)] { + override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = File.createTempFile(blockId.toString, null, tempDir) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated.append(file) + (blockId, file) + } + }) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + new Answer[File] { + override def answer(invocation: InvocationOnMock): File = { + blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get + } + }) + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + blockIdToFileMap.clear() + temporaryFilesCreated.clear() + } + + test("write empty iterator") { + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new 
HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + writer.insertAll(Iterator.empty) + val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) + assert(partitionLengths.sum === 0) + assert(outputFile.exists()) + assert(outputFile.length() === 0) + assert(temporaryFilesCreated.isEmpty) + assert(shuffleWriteMetrics.shuffleBytesWritten === 0) + assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write with some empty partitions") { + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + writer.insertAll(records) + assert(temporaryFilesCreated.nonEmpty) + val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) + assert(partitionLengths.sum === outputFile.length()) + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("cleanup of intermediate files after errors") { + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + intercept[SparkException] { + writer.insertAll((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(temporaryFilesCreated.nonEmpty) + writer.stop() + assert(temporaryFilesCreated.count(_.exists()) === 0) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..34b4984f12c09 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort + +import org.mockito.Mockito._ + +import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} + +class SortShuffleWriterSuite extends SparkFunSuite { + + import SortShuffleWriter._ + + test("conditions for bypassing merge-sort") { + val conf = new SparkConf(loadDefaults = false) + val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) + val ord = implicitly[Ordering[Int]] + + // Numbers of partitions that are above and below the default bypassMergeThreshold + val FEW_PARTITIONS = 50 + val MANY_PARTITIONS = 10000 + + // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high + assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) + assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) + + // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions + assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) + assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala index 49a04a2a45280..a73e94e05575e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.shuffle.unsafe import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} @@ -29,7 +29,7 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are * performed in other suites. 
*/ -class UnsafeShuffleManagerSuite extends FunSuite with Matchers { +class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { import UnsafeShuffleManager.canUseUnsafeShuffle diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamTest.scala b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala similarity index 73% rename from core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamTest.scala rename to core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala index 5274df904d395..63b0e77629dde 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamTest.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala @@ -16,14 +16,21 @@ */ package org.apache.spark.status.api.v1 -import org.scalatest.{Matchers, FunSuite} +import javax.ws.rs.WebApplicationException -class SimpleDateParamTest extends FunSuite with Matchers { +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SimpleDateParamSuite extends SparkFunSuite with Matchers { test("date parsing") { new SimpleDateParam("2015-02-20T23:21:17.190GMT").timestamp should be (1424474477190L) - new SimpleDateParam("2015-02-20T17:21:17.190CST").timestamp should be (1424474477190L) - new SimpleDateParam("2015-02-20").timestamp should be (1424390400000L) // GMT + new SimpleDateParam("2015-02-20T17:21:17.190EST").timestamp should be (1424470877190L) + new SimpleDateParam("2015-02-20").timestamp should be (1424390400000L) // GMT + intercept[WebApplicationException] { + new SimpleDateParam("invalid date") + } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index b647e8a6728ec..89ed031b6fcd1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.storage -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BlockIdSuite extends FunSuite { +class BlockIdSuite extends SparkFunSuite { def assertSame(id1: BlockId, id2: BlockId) { assert(id1.name === id2.name) assert(id1.hashCode === id2.hashCode) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f647200402ecb..0f5ba46f69c2f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -23,11 +23,11 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} +import org.apache.spark._ import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -36,7 +36,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { 
+class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) var rpcEnv: RpcEnv = null diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 151955ef7f435..bcee901f5dd5f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} +import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -41,7 +41,7 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ -class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach +class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { private val conf = new SparkConf(false) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index 43ef469c1fd48..7bdea724fea58 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -18,16 +18,28 @@ package org.apache.spark.storage import java.io.File -import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -class BlockObjectWriterSuite extends FunSuite { +class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + var tempDir: File = _ + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + } + test("verify write metrics") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -49,7 +61,7 @@ class BlockObjectWriterSuite extends FunSuite { } test("verify write metrics on revert") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -72,7 +84,7 @@ class BlockObjectWriterSuite extends FunSuite { } test("Reopening a closed block writer") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -83,4 +95,79 @@ class BlockObjectWriterSuite extends 
FunSuite { writer.open() } } + + test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + } + + test("commitAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.commitAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("revertPartialWritesAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.revertPartialWritesAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 0) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("fileSegment() can only be called after commitAndClose() has been called") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + intercept[IllegalStateException] { + writer.fileSegment() + } + writer.close() + } + + test("commitAndClose() without ever opening or writing") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + writer.commitAndClose() + assert(writer.fileSegment().length === 0) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bc5c74c126b74..688f56f4665f3 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -22,12 +22,12 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls import org.mockito.Mockito.{mock, 
when} -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { +class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) private var rootDir0: File = _ private var rootDir1: File = _ diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala index bcf138b5ee6d0..b21c91f75d5c7 100644 --- a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.storage -import org.scalatest.FunSuite -import org.apache.spark.{SharedSparkContext, SparkConf, LocalSparkContext, SparkContext} +import org.apache.spark._ -class FlatmapIteratorSuite extends FunSuite with LocalSparkContext { +class FlatmapIteratorSuite extends SparkFunSuite with LocalSparkContext { /* Tests the ability of Spark to deal with user provided iterators from flatMap * calls, that may generate more data then available memory. In any * memory based persistance Spark will unroll the iterator into an ArrayBuffer @@ -59,10 +58,10 @@ class FlatmapIteratorSuite extends FunSuite with LocalSparkContext { .set("spark.serializer.objectStreamReset", "10") sc = new SparkContext(sconf) val expand_size = 500 - val data = sc.parallelize(Seq(1,2)). + val data = sc.parallelize(Seq(1, 2)). flatMap(x => Stream.range(1, expand_size). - map(y => "%d: string test %d".format(y,x))) - var persisted = data.persist(StorageLevel.MEMORY_ONLY_SER) + map(y => "%d: string test %d".format(y, x))) + val persisted = data.persist(StorageLevel.MEMORY_ONLY_SER) assert(persisted.filter(_.startsWith("1:")).count()===2) } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index b47157f8331cc..ac6fec56bbf4f 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -20,15 +20,15 @@ package org.apache.spark.storage import java.io.File import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. 
*/ -class LocalDirsSuite extends FunSuite with BeforeAndAfter { +class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { before { Utils.clearLocalRootDirs() diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2080c432d77db..2a7fe67ad8585 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -26,15 +26,14 @@ import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, TaskContextImpl} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends FunSuite { +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 3a45875391e29..1a199beb3558f 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.storage -import org.scalatest.FunSuite -import org.apache.spark.Success +import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ /** * Test the behavior of StorageStatusListener in response to all relevant events. */ -class StorageStatusListenerSuite extends FunSuite { +class StorageStatusListenerSuite extends SparkFunSuite { private val bm1 = BlockManagerId("big", "dog", 1) private val bm2 = BlockManagerId("fat", "duck", 2) private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index 17193ddbfd894..1d5a813a4d336 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.storage -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite /** * Test various functionalities in StorageUtils and StorageStatus. */ -class StorageSuite extends FunSuite { +class StorageSuite extends SparkFunSuite { private val memAndDisk = StorageLevel.MEMORY_AND_DISK // For testing add, update, and remove (for non-RDD blocks) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index b6f5accef0cef..33712f1bfa782 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} /** * Selenium tests for the Spark Web UI. 
*/ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll { +class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll { implicit var webDriver: WebDriver = _ implicit val formats = DefaultFormats @@ -483,11 +483,11 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before val jobsJson = getJson(sc.ui.get, "jobs") jobsJson.children.size should be (expJobInfo.size) for { - (job @ JObject(_),idx) <- jobsJson.children.zipWithIndex + (job @ JObject(_), idx) <- jobsJson.children.zipWithIndex id = (job \ "jobId").extract[String] name = (job \ "name").extract[String] } { - withClue(s"idx = $idx; id = $id; name = ${name.substring(0,20)}") { + withClue(s"idx = $idx; id = $id; name = ${name.substring(0, 20)}") { id should be (expJobInfo(idx)._1) name should include (expJobInfo(idx)._2) } @@ -540,12 +540,12 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before goToUi(sc, "/stages/stage/?id=12&attempt=0") find("no-info").get.text should be ("No information to display for Stage 12 (Attempt 0)") - val badStage = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get,"stages/12/0")) + val badStage = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get, "stages/12/0")) badStage._1 should be (HttpServletResponse.SC_NOT_FOUND) badStage._2 should be (None) badStage._3 should be (Some("unknown stage: 12")) - val badAttempt = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get,"stages/19/15")) + val badAttempt = HistoryServerSuite.getContentAndCode(apiUrl(sc.ui.get, "stages/19/15")) badAttempt._1 should be (HttpServletResponse.SC_NOT_FOUND) badAttempt._2 should be (None) badAttempt._3 should be (Some("unknown attempt for stage 19. Found attempts: [0]")) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 77a038dc1720d..8f9502b5673d1 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -23,14 +23,13 @@ import scala.io.Source import scala.util.{Failure, Success, Try} import org.eclipse.jetty.servlet.ServletContextHandler -import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.LocalSparkContext._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class UISuite extends FunSuite { +class UISuite extends SparkFunSuite { /** * Create a test SparkContext with the SparkUI enabled. 
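Every suite in this patch now extends SparkFunSuite rather than ScalaTest's FunSuite directly. The base class itself is defined elsewhere in the patch and is not shown in these hunks; purely as an illustrative sketch (an assumption about its shape, not the actual definition), such a wrapper could look roughly like the following, reusing ScalaTest's withFixture hook and Spark's Logging trait to log each test name:

package org.apache.spark

import org.scalatest.{FunSuite, Outcome}

// Illustrative sketch only: a FunSuite wrapper that logs suite/test names around each test.
private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
  protected override def withFixture(test: NoArgTest): Outcome = {
    val suiteName = this.getClass.getName
    val testName = test.name
    try {
      logInfo(s"\n\n===== TEST OUTPUT FOR $suiteName: '$testName' =====\n")
      test() // run the actual test body
    } finally {
      logInfo(s"\n\n===== FINISHED $suiteName: '$testName' =====\n")
    }
  }
}

With a shared base class of this kind, the per-file changes stay mechanical: only the import and the extends clause change, as the hunks above and below show.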
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 967dd0821ebd0..56f7b9cf1f358 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.ui.jobs import java.util.Properties -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark._ @@ -28,7 +27,7 @@ import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils -class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { +class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala index c1126f3af52e6..86b078851851f 100644 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.ui.scope -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SparkListenerStageSubmitted import org.apache.spark.scheduler.SparkListenerStageCompleted @@ -28,7 +26,7 @@ import org.apache.spark.scheduler.SparkListenerJobStart /** * Tests that this listener populates and cleans up its data structures properly. */ -class RDDOperationGraphListenerSuite extends FunSuite { +class RDDOperationGraphListenerSuite extends SparkFunSuite { private var jobIdCounter = 0 private var stageIdCounter = 0 private val maxRetainedJobs = 10 diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 7b38e6d9473e1..37e2670de9685 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.ui.storage -import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.Success +import org.scalatest.BeforeAndAfter +import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.storage._ @@ -26,7 +26,7 @@ import org.apache.spark.storage._ /** * Test various functionality in the StorageListener that supports the StorageTab. 
*/ -class StorageTabSuite extends FunSuite with BeforeAndAfter { +class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { private var bus: LiveListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ @@ -169,7 +169,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { test("verify StorageTab contains all cached rdds") { val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly, Seq(4)) - val rddInfo1 = new RDDInfo(1, "rdd1", 1 ,memOnly, Seq(4)) + val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly, Seq(4)) val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), Seq.empty, "details") val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details") val taskMetrics0 = new TaskMetrics diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index bec79fc4dc8f7..6c40685484ed4 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException import akka.actor.ActorNotFound -import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.rpc.RpcEnv @@ -32,7 +31,7 @@ import org.apache.spark.SSLSampleConfigs._ /** * Test the AkkaUtils with various security settings. */ -class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { +class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("remote fetch security bad password") { val conf = new SparkConf @@ -138,7 +137,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val slaveRpcEnv =RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) slaveTracker.trackerEndpoint = slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 7b165fe28bdd3..70cd27b04347d 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -20,14 +20,12 @@ package org.apache.spark.util import java.io.NotSerializableException import java.util.Random -import org.scalatest.FunSuite - import org.apache.spark.LocalSparkContext._ -import org.apache.spark.{TaskContext, SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.partial.CountEvaluator import org.apache.spark.rdd.RDD -class ClosureCleanerSuite extends FunSuite { +class ClosureCleanerSuite extends SparkFunSuite { test("closures inside an object") { assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 } @@ -203,7 +201,7 @@ object TestObjectWithNestedReturns { def run(): Int = { withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map {x => + nums.map {x => // this return is fine since it will not transfer control outside the closure def foo(): Int = { return 5; 1 } foo() diff --git 
a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 59456790e89f0..3147c937769d2 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -21,16 +21,16 @@ import java.io.NotSerializableException import scala.collection.mutable -import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite} import org.apache.spark.serializer.SerializerInstance /** * Another test suite for the closure cleaner that is finer-grained. * For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}. */ -class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { +class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with PrivateMethodTester { // Start a SparkContext so that the closure serializer is accessible // We do not actually use this explicitly otherwise diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala index 3755d43e25ea8..688fcd9f9aaba 100644 --- a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class CompletionIteratorSuite extends FunSuite { +class CompletionIteratorSuite extends SparkFunSuite { test("basic test") { var numTimesCompleted = 0 val iter = List(1, 2, 3).iterator diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala index 090d48ec921a1..cdd6555697c23 100644 --- a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.util -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite + /** * */ -class DistributionSuite extends FunSuite with Matchers { +class DistributionSuite extends SparkFunSuite with Matchers { test("summary") { val d = new Distribution((1 to 100).toArray.map{_.toDouble}) val stats = d.statCounter diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 47b535206c949..b207d497f33c2 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -25,9 +25,10 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts -import org.scalatest.FunSuite -class EventLoopSuite extends FunSuite with Timeouts { +import org.apache.spark.SparkFunSuite + +class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index c05317534cddf..2b76ae1f8a24b 100644 --- 
a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -22,15 +22,15 @@ import java.io._ import scala.collection.mutable.HashSet import scala.reflect._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} -class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { +class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { val testFile = new File(Utils.createTempDir(), "FileAppenderSuite-test").getAbsoluteFile diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 0d9126f23ccc5..e0ef9c70a5fc3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import scala.collection.Map import org.json4s.jackson.JsonMethods._ -import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor._ @@ -33,7 +32,7 @@ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.storage._ -class JsonProtocolSuite extends FunSuite { +class JsonProtocolSuite extends SparkFunSuite { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 87de90bb0dfb0..42125547436cb 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -19,11 +19,9 @@ package org.apache.spark.util import java.net.URLClassLoader -import org.scalatest.FunSuite +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} -import org.apache.spark.{SparkContext, SparkException, TestUtils} - -class MutableURLClassLoaderSuite extends FunSuite { +class MutableURLClassLoaderSuite extends SparkFunSuite { val urls2 = List(TestUtils.createJarWithClasses( classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 403dcb03bd6e5..4b7164d8acbce 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -21,10 +21,11 @@ import java.util.NoSuchElementException import scala.collection.mutable.Buffer -import org.scalatest.FunSuite import org.scalatest.Matchers -class NextIteratorSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class NextIteratorSuite extends SparkFunSuite with Matchers { test("one iteration") { val i = new StubIterator(Buffer(1)) i.hasNext should be (true) diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index bad1aa99952cf..c58db5e606f7c 100644 --- 
a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -22,12 +22,14 @@ import java.util.Properties import org.apache.commons.lang3.SerializationUtils import org.scalatest.{BeforeAndAfterEach, Suite} +import org.apache.spark.SparkFunSuite + /** * Mixin for automatically resetting system properties that are modified in ScalaTest tests. * This resets the properties after each individual test. * * The order in which fixtures are mixed in affects the order in which they are invoked by tests. - * If we have a suite `MySuite extends FunSuite with Foo with Bar`, then + * If we have a suite `MySuite extends SparkFunSuite with Foo with Bar`, then * Bar's `super` is Foo, so Bar's beforeEach() will and afterEach() methods will be invoked first * by the rest runner. * diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 04f0f3749d6b9..20550178fb1bd 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.util import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, PrivateMethodTester} + +import org.apache.spark.SparkFunSuite class DummyClass1 {} @@ -59,7 +61,10 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester + with ResetSystemProperties { override def beforeEach() { // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 751d3df9cc8f7..8c51e6b14b7fc 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -23,9 +23,9 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ThreadUtilsSuite extends FunSuite { +class ThreadUtilsSuite extends SparkFunSuite { test("newDaemonSingleThreadExecutor") { val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name") diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index 8b72fe665c214..9b3169026cda3 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class TimeStampedHashMapSuite extends FunSuite { +class TimeStampedHashMapSuite extends SparkFunSuite { // Test the testMap function - a Scala HashMap should obviously pass testMap(new mutable.HashMap[String, String]()) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 61152c29a681f..a61ea3918f46a 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -29,16 +29,15 @@ import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.network.util.ByteUnit -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.SparkConf -class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { +class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("timeConversion") { // Test -1 @@ -551,7 +550,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { test("fetch hcfs dir") { val tempDir = Utils.createTempDir() val sourceDir = new File(tempDir, "source-dir") - val innerSourceDir = Utils.createTempDir(root=sourceDir.getPath) + val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") Files.write("some text", sourceFile, UTF_8) @@ -609,4 +608,69 @@ class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { manager.runAll() assert(output.toList === List(4, 3, 2)) } + + test("isInDirectory") { + val tmpDir = new File(sys.props("java.io.tmpdir")) + val parentDir = new File(tmpDir, "parent-dir") + val childDir1 = new File(parentDir, "child-dir-1") + val childDir1b = new File(parentDir, "child-dir-1b") + val childFile1 = new File(parentDir, "child-file-1.txt") + val childDir2 = new File(childDir1, "child-dir-2") + val childDir2b = new File(childDir1, "child-dir-2b") + val childFile2 = new File(childDir1, "child-file-2.txt") + val childFile3 = new File(childDir2, "child-file-3.txt") + val nullFile: File = null + parentDir.mkdir() + childDir1.mkdir() + childDir1b.mkdir() + childDir2.mkdir() + childDir2b.mkdir() + childFile1.createNewFile() + childFile2.createNewFile() + childFile3.createNewFile() + + // Identity + assert(Utils.isInDirectory(parentDir, parentDir)) + assert(Utils.isInDirectory(childDir1, childDir1)) + assert(Utils.isInDirectory(childDir2, childDir2)) + + // Valid ancestor-descendant pairs + assert(Utils.isInDirectory(parentDir, childDir1)) + assert(Utils.isInDirectory(parentDir, childFile1)) + assert(Utils.isInDirectory(parentDir, childDir2)) + assert(Utils.isInDirectory(parentDir, childFile2)) + assert(Utils.isInDirectory(parentDir, childFile3)) + assert(Utils.isInDirectory(childDir1, childDir2)) + assert(Utils.isInDirectory(childDir1, childFile2)) + assert(Utils.isInDirectory(childDir1, childFile3)) + assert(Utils.isInDirectory(childDir2, childFile3)) + + // Inverted ancestor-descendant pairs should fail + assert(!Utils.isInDirectory(childDir1, parentDir)) + assert(!Utils.isInDirectory(childDir2, parentDir)) + assert(!Utils.isInDirectory(childDir2, childDir1)) + assert(!Utils.isInDirectory(childFile1, parentDir)) + assert(!Utils.isInDirectory(childFile2, parentDir)) + assert(!Utils.isInDirectory(childFile3, parentDir)) + assert(!Utils.isInDirectory(childFile2, childDir1)) + assert(!Utils.isInDirectory(childFile3, childDir1)) + assert(!Utils.isInDirectory(childFile3, childDir2)) + + // Non-existent files or directories should fail + assert(!Utils.isInDirectory(parentDir, new 
File(parentDir, "one.txt"))) + assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one/two.txt"))) + assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one/two/three.txt"))) + + // Siblings should fail + assert(!Utils.isInDirectory(childDir1, childDir1b)) + assert(!Utils.isInDirectory(childDir1, childFile1)) + assert(!Utils.isInDirectory(childDir2, childDir2b)) + assert(!Utils.isInDirectory(childDir2, childFile2)) + + // Null files should fail without throwing NPE + assert(!Utils.isInDirectory(parentDir, nullFile)) + assert(!Utils.isInDirectory(childFile3, nullFile)) + assert(!Utils.isInDirectory(nullFile, parentDir)) + assert(!Utils.isInDirectory(nullFile, childFile3)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index ce2968728a996..11194cd22a419 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.util import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite /** * Tests org.apache.spark.util.Vector functionality */ @deprecated("suppress compile time deprecation warning", "1.0.0") -class VectorSuite extends FunSuite { +class VectorSuite extends SparkFunSuite { def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) diff --git a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala index cb99d14b27af4..a2a6d703860f2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala @@ -21,9 +21,9 @@ import java.util.Comparator import scala.collection.mutable.HashSet -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class AppendOnlyMapSuite extends FunSuite { +class AppendOnlyMapSuite extends SparkFunSuite { test("initialization") { val goodMap1 = new AppendOnlyMap[Int, Int](1) assert(goodMap1.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index b85a409a4b2e9..69dbfa9cd7141 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BitSetSuite extends FunSuite { +class BitSetSuite extends SparkFunSuite { test("basic set and get") { val setBits = Seq(0, 9, 1, 10, 90, 96) @@ -94,7 +94,7 @@ class BitSetSuite extends FunSuite { test( "xor len(bitsetX) > len(bitsetY)" ) { val setBitsX = Seq( 0, 1, 3, 37, 38, 41, 85) - val setBitsY = Seq( 0, 2, 3, 37, 41 ) + val setBitsY = Seq( 0, 2, 3, 37, 41) val bitsetX = new BitSet(100) setBitsX.foreach( i => bitsetX.set(i)) val bitsetY = new BitSet(60) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala index c0c38cd4ac4ad..05306f408847d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala @@ -19,10 +19,11 @@ 
package org.apache.spark.util.collection import java.nio.ByteBuffer -import org.scalatest.FunSuite import org.scalatest.Matchers._ -class ChainedBufferSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class ChainedBufferSuite extends SparkFunSuite { test("write and read at start") { // write from start of source array val buffer = new ChainedBuffer(8) diff --git a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala index 6c956d93dc80d..bc5479991a99d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class CompactBufferSuite extends FunSuite { +class CompactBufferSuite extends SparkFunSuite { test("empty buffer") { val b = new CompactBuffer[Int] assert(b.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index dff8f3ddc816f..79eba61a87251 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.io.CompressionCodec -class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { +class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private def createCombiner[T](i: T) = ArrayBuffer[T](i) private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 7a98723bc6472..9cefa612f5491 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,14 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.{FunSuite, PrivateMethodTester} - import scala.util.Random import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { +class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) if (kryo) { @@ -37,21 +35,12 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.serializer.objectStreamReset", "1") conf.set("spark.serializer", classOf[JavaSerializer].getName) } + conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") // Ensure that we actually have multiple batches per spill file conf.set("spark.shuffle.spill.batchSize", "10") conf } - private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { - val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) - 
assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort") - } - - private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { - val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) - assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort") - } - test("empty data stream with kryo ser") { emptyDataStream(createSparkConf(false, true)) } @@ -161,39 +150,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) - sorter.insertAll(elements) - assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled - val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) - assert(iter.next() === (0, Nil)) - assert(iter.next() === (1, List((1, 1)))) - assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) - assert(iter.next() === (3, Nil)) - assert(iter.next() === (4, Nil)) - assert(iter.next() === (5, List((5, 5)))) - assert(iter.next() === (6, Nil)) - sorter.stop() - } - - test("empty partitions with spilling, bypass merge-sort with kryo ser") { - emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, true)) - } - - test("empty partitions with spilling, bypass merge-sort with java ser") { - emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, false)) - } - - def emptyPartitionerWithSpillingBypassMergeSort(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), None, None) - assertBypassedMergeSort(sorter) sorter.insertAll(elements) assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) @@ -376,7 +332,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) sorter.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) sorter.stop() @@ -384,7 +339,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter2 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter2) sorter2.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet) @@ -392,29 +346,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(diskBlockManager.getAllBlocks().length === 0) } - test("cleanup of intermediate files in sorter, bypass merge-sort") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = 
SparkEnv.get.blockManager.diskBlockManager - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter) - sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - - val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter2) - sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) - sorter2.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - } - test("cleanup of intermediate files in sorter if there are errors") { val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") @@ -426,7 +357,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) intercept[SparkException] { sorter.insertAll((0 until 120000).iterator.map(i => { if (i == 119990) { @@ -440,28 +370,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(diskBlockManager.getAllBlocks().length === 0) } - test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter) - intercept[SparkException] { - sorter.insertAll((0 until 100000).iterator.map(i => { - if (i == 99990) { - throw new SparkException("Intentional failure") - } - (i, i) - })) - } - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - } - test("cleanup of intermediate files in shuffle") { val conf = createSparkConf(false, false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -776,40 +684,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } } - test("conditions for bypassing merge-sort") { - val conf = createSparkConf(false, false) - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - - // Numbers of partitions that are above and below the default bypassMergeThreshold - val FEW_PARTITIONS = 50 - val MANY_PARTITIONS = 10000 - - // Sorters with no ordering or aggregator: should bypass unless # of partitions is high - - val sorter1 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None) - assertBypassedMergeSort(sorter1) - - val sorter2 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None) - assertDidNotBypassMergeSort(sorter2) - - 
// Sorters with an ordering or aggregator: should not bypass even if they have few partitions - - val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None) - assertDidNotBypassMergeSort(sorter3) - - val sorter4 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) - assertDidNotBypassMergeSort(sorter4) - } - test("sort without breaking sorting contracts with kryo ser") { sortWithoutBreakingSortingContracts(createSparkConf(true, true)) } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index ef890d2ba60f3..94e011799921b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class OpenHashMapSuite extends FunSuite with Matchers { +class OpenHashMapSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive value (int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 68a03e3a0970f..2607a543dd614 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite with Matchers { +class OpenHashSetSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive int") { val loadFactor = 0.7 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index b5a2d9ef720c1..6d2459d48d326 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -21,14 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} import com.google.common.io.ByteStreams -import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.storage.{FileSegment, BlockObjectWriter} -class PartitionedSerializedPairBufferSuite extends FunSuite { +class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { val serializerInstance = new KryoSerializer(new SparkConf()).newInstance diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index caf378fec8b3e..462bc2f29f9f8 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { +class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive key, value (int, int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala index 970dade628fe4..ae0eebc26f01b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class PrimitiveVectorSuite extends FunSuite { +class PrimitiveVectorSuite extends SparkFunSuite { test("primitive value") { val vector = new PrimitiveVector[Int] diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 1f33967249654..5a5919fca2469 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.util.collection import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class SizeTrackerSuite extends FunSuite { +class SizeTrackerSuite extends SparkFunSuite { val NORMAL_ERROR = 0.20 val HIGH_ERROR = 0.30 diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index e0d6cc16bde05..b2f5d9009ee5d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends FunSuite { +class SorterSuite extends SparkFunSuite { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) @@ -104,9 +103,6 @@ class SorterSuite extends FunSuite { * has the keys and values alternating. The basic Java sorts work only on the keys, so the * real Java solution is to make Tuple2s to store the keys and values and sort an array of * those, while the Sorter approach can work directly on the input data format. - * - * Note that the Java implementation varies tremendously between Java 6 and Java 7, when - * the Java sort changed from merge sort to TimSort. 
*/ ignore("Sorter benchmark for key-value pairs") { val numElements = 25000000 // 25 mil diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala index f855831b8e367..361ec95654f47 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.util.io import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ByteArrayChunkOutputStreamSuite extends FunSuite { +class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { test("empty output") { val o = new ByteArrayChunkOutputStream(1024) diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index 20944b62473c5..d6af0aebde733 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -21,9 +21,11 @@ import java.util.Random import scala.collection.mutable.ArrayBuffer import org.apache.commons.math3.distribution.PoissonDistribution -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers -class RandomSamplerSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class RandomSamplerSuite extends SparkFunSuite with Matchers { /** * My statistical testing methodology is to run a Kolmogorov-Smirnov (KS) test * between the random samplers and simple reference samplers (known to work correctly). @@ -76,7 +78,7 @@ class RandomSamplerSuite extends FunSuite with Matchers { } // Returns iterator over gap lengths between samples. - // This function assumes input data is integers sampled from the sequence of + // This function assumes input data is integers sampled from the sequence of // increasing integers: {0, 1, 2, ...}. 
This works because that is how I generate them, // and the samplers preserve their input order def gaps(data: Iterator[Int]): Iterator[Int] = { diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index 73a9d029b0248..667a4db6f7bb6 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.util.random import scala.util.Random import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} -import org.scalatest.FunSuite -class SamplingUtilsSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class SamplingUtilsSuite extends SparkFunSuite { test("reservoirSampleAndCount") { val input = Seq.fill(100)(Random.nextInt()) diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 03f5f2d1b8528..d26667bf720cf 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.util.random -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.commons.math3.stat.inference.ChiSquareTest +import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils.times import scala.language.reflectiveCalls -class XORShiftRandomSuite extends FunSuite with Matchers { +class XORShiftRandomSuite extends SparkFunSuite with Matchers { - def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { + private def fixture = new { val seed = 1L val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 787c5cc8e892d..cd83b352c1bfb 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -44,9 +44,9 @@ # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") # ASF JIRA username -JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password -JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "35500") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" diff --git a/dev/run-tests b/dev/run-tests index 7dd8d31fd44e3..d178e2a4601ea 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -80,18 +80,19 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" # Only run Hive tests if there are SQL changes. # Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then - git fetch origin master:master + target_branch="$ghprbTargetBranch" + git fetch origin "$target_branch":"$target_branch" # AMP_JENKINS_PRB indicates if the current build is a pull request build. if [ -n "$AMP_JENKINS_PRB" ]; then # It is a pull request build. 
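# (Editorial aside, not part of this patch: the intent of the change above is that
#  Hive/SQL test selection now diffs against the PR's target branch instead of a
#  hard-coded master. A hypothetical local dry run, assuming Jenkins has exported
#  ghprbTargetBranch, would be:
#      target_branch="$ghprbTargetBranch"
#      git fetch origin "$target_branch":"$target_branch"
#      git diff --name-only "$target_branch" | grep -e "^sql/"
#  a non-empty result is what triggers the Hive tests below.)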
sql_diffs=$( - git diff --name-only master \ + git diff --name-only "$target_branch" \ | grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" ) non_sql_diffs=$( - git diff --name-only master \ + git diff --name-only "$target_branch" \ | grep -v -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" ) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 8b2a44fd72ba5..641b0ff3c4be4 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -47,7 +47,9 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" -TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout +# format: http://linux.die.net/man/1/timeout +# must be less than the timeout configured on Jenkins (currently 180m) +TESTS_TIMEOUT="175m" # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. @@ -191,7 +193,7 @@ done test_result="$?" if [ "$test_result" -eq "124" ]; then - fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \ + fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}console)** \ for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ after a configured wait of \`${TESTS_TIMEOUT}\`." @@ -231,7 +233,7 @@ done # post end message { result_message="\ - [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \ + [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}console) for \ PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" diff --git a/docs/_config.yml b/docs/_config.yml index b22b627f09007..c0e031a83ba9c 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.4.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.4.0 +SPARK_VERSION: 1.5.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.5.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index b92c75f90b11c..eebb3faf90fc0 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -75,6 +75,7 @@
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • + SparkR (R on Spark)
  • diff --git a/docs/building-spark.md b/docs/building-spark.md index 4dbccb9e6e46c..2128fdffecc05 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -7,11 +7,7 @@ redirect_from: "building-with-maven.html" * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven requires Maven 3.0.4 or newer and Java 6+. - -**Note:** Building Spark with Java 7 or later can create JAR files that may not be -readable with early versions of Java 6, due to the large number of files in the JAR -archive. Build with Java 6 if this is an issue for your deployment. +Building Spark using Maven requires Maven 3.0.4 or newer and Java 7+. # Building with `build/mvn` @@ -80,6 +76,7 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro 2.2.xhadoop-2.2 2.3.xhadoop-2.3 2.4.xhadoop-2.4 + 2.6.x and later 2.xhadoop-2.6 @@ -118,14 +115,10 @@ mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` and `Phive-thriftserver` profiles to your existing build options. -By default Spark will build with Hive 0.13.1 bindings. You can also build for -Hive 0.12.0 using the `-Phive-0.12.0` profile. +By default Spark will build with Hive 0.13.1 bindings. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package - -# Apache Hadoop 2.4.X with Hive 12 support -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-0.12.0 -Phive-thriftserver -DskipTests clean package {% endhighlight %} # Building for Scala 2.11 @@ -134,9 +127,7 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop dev/change-version-to-2.11.sh mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package -Scala 2.11 support in Spark does not support a few features due to dependencies -which are themselves not Scala 2.11 ready. Specifically, Spark's external -Kafka library and JDBC component are not yet supported in Scala 2.11 builds. +Spark does not yet support its JDBC component for Scala 2.11. # Spark Tests in Maven @@ -180,7 +171,7 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m # Building Spark with IntelliJ IDEA or Eclipse For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-IDESetup). +[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IDESetup). 
# Running Java 8 Test Suites diff --git a/docs/configuration.md b/docs/configuration.md index 30508a617fdd8..3960e7e78bde1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,4 +1,4 @@ --- +--- layout: global displayTitle: Spark Configuration title: Configuration @@ -618,7 +618,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.referenceTracking - true + true (false when using Spark SQL Thrift Server) Whether to track references to the same object when serializing data with Kryo, which is necessary if your object graphs have loops and useful for efficiency if they contain multiple @@ -679,7 +679,10 @@ Apart from these, the following properties are also available, and may be useful spark.serializer - org.apache.spark.serializer.
    JavaSerializer + + org.apache.spark.serializer.
    JavaSerializer (org.apache.spark.serializer.
    + KryoSerializer when using Spark SQL Thrift Server) + Class to use for serializing objects that will be sent over the network or need to be cached in serialized form. The default of Java serialization works with any Serializable Java object @@ -1201,6 +1204,15 @@ Apart from these, the following properties are also available, and may be useful description. + + spark.dynamicAllocation.cachedExecutorIdleTimeout + 2 * executorIdleTimeout + + If dynamic allocation is enabled and an executor which has cached data blocks has been idle for more than this duration, + the executor will be removed. For more details, see this + description. + + spark.dynamicAllocation.initialExecutors spark.dynamicAllocation.minExecutors diff --git a/docs/hadoop-provided.md b/docs/hadoop-provided.md new file mode 100644 index 0000000000000..0ba5a58051abc --- /dev/null +++ b/docs/hadoop-provided.md @@ -0,0 +1,26 @@ +--- +layout: global +displayTitle: Using Spark's "Hadoop Free" Build +title: Using Spark's "Hadoop Free" Build +--- + +Spark uses Hadoop client libraries for HDFS and YARN. Starting in version Spark 1.4, the project packages "Hadoop free" builds that lets you more easily connect a single Spark binary to any Hadoop version. To use these builds, you need to modify `SPARK_DIST_CLASSPATH` to include Hadoop's package jars. The most convenient place to do this is by adding an entry in `conf/spark-env.sh`. + +This page describes how to connect Spark to Hadoop for different types of distributions. + +# Apache Hadoop +For Apache distributions, you can use Hadoop's 'classpath' command. For instance: + +{% highlight bash %} +### in conf/spark-env.sh ### + +# If 'hadoop' binary is on your PATH +export SPARK_DIST_CLASSPATH=$(hadoop classpath) + +# With explicit path to 'hadoop' binary +export SPARK_DIST_CLASSPATH=$(/path/to/hadoop/bin/hadoop classpath) + +# Passing a Hadoop configuration directory +export SPARK_DIST_CLASSPATH=$(hadoop classpath --config /path/to/configs) + +{% endhighlight %} diff --git a/docs/index.md b/docs/index.md index 5ef6d983c45a5..d85cf12defefd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,15 +12,19 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog # Downloading -Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. The downloads page -contains Spark packages for many popular HDFS versions. If you'd like to build Spark from -scratch, visit [Building Spark](building-spark.html). +Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. +Users can also download a "Hadoop free" binary and run Spark with any Hadoop version +[by augmenting Spark's classpath](hadoop-provided.html). + +If you'd like to build Spark from +source, visit [Building Spark](building-spark.html). + Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy to run locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 6+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses +Spark runs on Java 7+, Python 2.6+ and R 3.1+. 
For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). @@ -54,7 +58,7 @@ Example applications are also provided in Python. For example, ./bin/spark-submit examples/src/main/python/pi.py 10 -Spark also provides an experimental R API since 1.4 (only DataFrames APIs included). +Spark also provides an experimental [R API](sparkr.html) since 1.4 (only DataFrames APIs included). To run Spark interactively in a R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] diff --git a/docs/ml-features.md b/docs/ml-features.md index efe9b3b8edb6e..f88c0248c1a8a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3):
    +## StringIndexer + +`StringIndexer` encodes a string column of labels to a column of label indices. +The indices are in `[0, numLabels)`, ordered by label frequencies. +So the most frequent label gets index `0`. +If the input column is numeric, we cast it to string and index the string values. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `category`: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | a + 4 | a + 5 | c +~~~~ + +`category` is a string column with three labels: "a", "b", and "c". +Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output +column, we should get the following: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 + 3 | a | 0.0 + 4 | a | 0.0 + 5 | c | 1.0 +~~~~ + +"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with +index `2`. + +
    + +
    + +[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight scala %} +import org.apache.spark.ml.feature.StringIndexer + +val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) +).toDF("id", "category") +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") +val indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
    + +
    +[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column +name and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); +DataFrame indexed = indexer.fit(df).transform(df); +indexed.show(); +{% endhighlight %} +
    + +
    + +[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight python %} +from pyspark.ml.feature import StringIndexer + +df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) +indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
    +
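The indexed column produced above is typically fed straight into the `OneHotEncoder` covered in the next section. A minimal sketch of that chaining, assuming the `indexed` DataFrame from the Scala example above (the output column name `categoryVec` is illustrative only):

{% highlight scala %}
import org.apache.spark.ml.feature.OneHotEncoder

// `indexed` is the DataFrame produced by the StringIndexer example above.
val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
val encoded = encoder.transform(indexed)
encoded.show()
{% endhighlight %}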
    + ## OneHotEncoder [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features @@ -789,6 +905,294 @@ scaledData = scalerModel.transform(dataFrame) +## Bucketizer + +`Bucketizer` transforms a column of continuous features to a column of feature buckets, where the buckets are specified by users. It takes a parameter: + +* `splits`: Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. Splits should be strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; Otherwise, values outside the splits specified will be treated as errors. Two examples of `splits` are `Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)` and `Array(0.0, 1.0, 2.0)`. + +Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potenial out of Bucketizer bounds exception. + +Note also that the splits that you provided have to be in strictly increasing order, i.e. `s0 < s1 < s2 < ... < sn`. + +More details can be found in the API docs for [Bucketizer](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer). + +The following example demonstrates how to bucketize a column of `Double`s into another index-wised column. + +
    +
    +{% highlight scala %} +import org.apache.spark.ml.feature.Bucketizer +import org.apache.spark.sql.DataFrame + +val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + +val data = Array(-0.5, -0.3, 0.0, 0.2) +val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + +val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + +// Transform original data into its bucket index. +val bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %} +
    + +
    +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + +JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame dataFrame = jsql.createDataFrame(data, schema); + +Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + +// Transform original data into its bucket index. +DataFrame bucketedData = bucketizer.transform(dataFrame); +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.ml.feature import Bucketizer + +splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + +data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] +dataFrame = sqlContext.createDataFrame(data, ["features"]) + +bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + +# Transform original data into its bucket index. +bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %} +
    +
    + +## ElementwiseProduct + +ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. + +`\[ \begin{pmatrix} +v_1 \\ +\vdots \\ +v_N +\end{pmatrix} \circ \begin{pmatrix} + w_1 \\ + \vdots \\ + w_N + \end{pmatrix} += \begin{pmatrix} + v_1 w_1 \\ + \vdots \\ + v_N w_N + \end{pmatrix} +\]` + +[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) takes the following parameter: + +* `scalingVec`: the transforming vector. + +This example below demonstrates how to transform vectors using a transforming vector value. + +
    +
    +{% highlight scala %} +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors + +// Create some vector data; also works for sparse vectors +val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + +val transformingVector = Vectors.dense(0.0, 1.0, 2.0) +val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + +// Batch transform the vectors to create new column: +val transformedData = transformer.transform(dataFrame) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +// Create some vector data; also works for sparse vectors +JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) +)); +List<StructField> fields = new ArrayList<StructField>(2); +fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); +fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); +StructType schema = DataTypes.createStructType(fields); +DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); +// Batch transform the vectors to create new column: +DataFrame transformedData = transformer.transform(dataFrame); + +{% endhighlight %} +
    +
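For intuition, the transformation defined by the formula above is plain elementwise multiplication; a tiny self-contained Scala check (illustrative only, no Spark objects involved), applying the `scalingVec` from the examples above to the first input row:

{% highlight scala %}
// Elementwise (Hadamard) product: result(i) = v(i) * w(i)
val v = Array(1.0, 2.0, 3.0)   // first input vector above
val w = Array(0.0, 1.0, 2.0)   // the scaling vector above
val result = v.zip(w).map { case (vi, wi) => vi * wi }
println(result.mkString("[", ", ", "]"))  // [0.0, 2.0, 6.0]
{% endhighlight %}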
    + +## VectorAssembler + +`VectorAssembler` is a transformer that combines a given list of columns into a single vector +column. +It is useful for combining raw features and features generated by different feature transformers +into a single feature vector, in order to train ML models like logistic regression and decision +trees. +`VectorAssembler` accepts the following input column types: all numeric types, boolean type, +and vector type. +In each row, the values of the input columns will be concatenated into a vector in the specified +order. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`, `mobile`, `userFeatures`, +and `clicked`: + +~~~ + id | hour | mobile | userFeatures | clicked +----|------|--------|------------------|--------- + 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 +~~~ + +`userFeatures` is a vector column that contains three user features. +We want to combine `hour`, `mobile`, and `userFeatures` into a single feature vector +called `features` and use it to predict `clicked` or not. +If we set `VectorAssembler`'s input columns to `hour`, `mobile`, and `userFeatures` and +output column to `features`, after transformation we should get the following DataFrame: + +~~~ + id | hour | mobile | userFeatures | clicked | features +----|------|--------|------------------|---------|----------------------------- + 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 | [18.0, 1.0, 0.0, 10.0, 0.5] +~~~ + +
    +
    + +[`VectorAssembler`](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) takes an array +of input column names and an output column name. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.feature.VectorAssembler + +val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) +).toDF("id", "hour", "mobile", "userFeatures", "clicked") +val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") +val output = assembler.transform(dataset) +println(output.select("features", "clicked").first()) +{% endhighlight %} +
    + +
    + +[`VectorAssembler`](api/java/org/apache/spark/ml/feature/VectorAssembler.html) takes an array +of input column names and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) +}); +Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); +JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + +DataFrame output = assembler.transform(dataset); +System.out.println(output.select("features", "clicked").first()); +{% endhighlight %} +
    + +
    + +[`VectorAssembler`](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) takes a list +of input column names and an output column name. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler + +dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) +assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") +output = assembler.transform(dataset) +print(output.select("features", "clicked").first()) +{% endhighlight %} +
    +
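The assembled `features` column is what `spark.ml` estimators consume. A minimal sketch of that hand-off, assuming the `output` DataFrame from the Scala example above; the choice of `LogisticRegression` is illustrative only:

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression

// `output` has the assembled "features" vector column and the "clicked" label column.
// In practice it would contain many labeled rows; one row is shown above only to
// illustrate the wiring from VectorAssembler into an estimator.
val lr = new LogisticRegression()
  .setFeaturesCol("features")
  .setLabelCol("clicked")
val model = lr.fit(output)
{% endhighlight %}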
    # Feature Selectors diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5f50ed7990f1..4eb622d4b95e8 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -207,7 +207,7 @@ val model1 = lr.fit(training.toDF) // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -println("Model 1 was fit using parameters: " + model1.fittingParamMap) +println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -222,7 +222,7 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training.toDF, paramMapCombined) -println("Model 2 was fit using parameters: " + model2.fittingParamMap) +println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. val test = sc.parallelize(Seq( @@ -289,7 +289,7 @@ LogisticRegressionModel model1 = lr.fit(training); // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap()); +System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); @@ -305,7 +305,7 @@ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); -System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap()); +System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. List localTest = Lists.newArrayList( diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index f41ca70952eb7..1b088969ddc25 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -47,7 +47,7 @@ Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasin optimal *k* is usually one where there is an "elbow" in the WSSSE graph. 
{% highlight scala %} -import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.clustering.{KMeans, KMeansModel} import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -62,6 +62,10 @@ val clusters = KMeans.train(parsedData, numClusters, numIterations) // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) println("Within Set Sum of Squared Errors = " + WSSSE) + +// Save and load model +clusters.save(sc, "myModelPath") +val sameModel = KMeansModel.load(sc, "myModelPath") {% endhighlight %} @@ -110,6 +114,10 @@ public class KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors double WSSSE = clusters.computeCost(parsedData.rdd()); System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + + // Save and load model + clusters.save(sc.sc(), "myModelPath"); + KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -124,7 +132,7 @@ Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by in fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. {% highlight python %} -from pyspark.mllib.clustering import KMeans +from pyspark.mllib.clustering import KMeans, KMeansModel from numpy import array from math import sqrt @@ -143,6 +151,10 @@ def error(point): WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y) print("Within Set Sum of Squared Error = " + str(WSSSE)) + +# Save and load model +clusters.save(sc, "myModelPath") +sameModel = KMeansModel.load(sc, "myModelPath") {% endhighlight %} @@ -237,11 +249,11 @@ public class GaussianMixtureExample { GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); // Save and load GaussianMixtureModel - gmm.save(sc, "myGMMModel") - GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + gmm.save(sc.sc(), "myGMMModel"); + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel"); // Output the parameters of the mixture model for(int j=0; j println(s"${a.id} -> ${a.cluster}") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") {% endhighlight %} A full example that produces the experiment described in the PIC paper can be found under @@ -360,6 +376,10 @@ PowerIterationClusteringModel model = pic.run(similarities); for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { System.out.println(a.id() + " -> " + a.cluster()); } + +// Save and load model +model.save(sc.sc(), "myModelPath"); +PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath"); {% endhighlight %} diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7b397e30b2d90..dfdf6216b270c 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -107,7 +107,8 @@ other signals), you can use the `trainImplicit` method to get better results. 
{% highlight scala %} val alpha = 0.01 -val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) +val lambda = 0.01 +val model = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha) {% endhighlight %} diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index f723cd6b9dfab..4fe470a8de810 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th import org.apache.spark._ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("text8").map(line => line.split(" ").toSeq) @@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
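The save and load calls above round-trip the fitted model; as a brief illustrative follow-up (assuming the `sameModel` just loaded), the reloaded model answers the same queries as the original:

{% highlight scala %}
// The reloaded Word2VecModel can be queried exactly like the original one.
val synonymsAgain = sameModel.findSynonyms("china", 5)
for ((synonym, cosineSimilarity) <- synonymsAgain) {
  println(s"$synonym $cosineSimilarity")
}
{% endhighlight %}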
    @@ -410,6 +414,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") @@ -505,7 +510,7 @@ v_N ### Example -This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value. +This example below demonstrates how to transform vectors using a transforming vector value.
    @@ -514,16 +519,44 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors -// Load and parse the data: -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +// Create some vector data; also works for sparse vectors +val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) val transformingVector = Vectors.dense(0.0, 1.0, 2.0) val transformer = new ElementwiseProduct(transformingVector) // Batch transform and per-row transform give the same results: -val transformedData = transformer.transform(parsedData) -val transformedData2 = parsedData.map(x => transformer.transform(x)) +val transformedData = transformer.transform(data) +val transformedData2 = data.map(x => transformer.transform(x)) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +// Create some vector data; also works for sparse vectors +JavaRDD<Vector> data = sc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +final ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + +// Batch transform and per-row transform give the same results: +JavaRDD<Vector> transformedData = transformer.transform(data); +JavaRDD<Vector> transformedData2 = data.map( + new Function<Vector, Vector>() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } +); {% endhighlight %}
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 9fd9be0dd01b1..bcc066a185526 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -39,11 +39,11 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters:
    -[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an -[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. {% highlight scala %} diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index b521c2f27cd6e..5732bc4c7e79e 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f labels and real labels in the test set. {% highlight scala %} -import org.apache.spark.mllib.regression.IsotonicRegression +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") @@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point => // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() println("Mean Squared Error = " + meanSquaredError) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( ).rdd()).mean(); System.out.println("Mean Squared Error = " + meanSquaredError); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 8029edca16002..3dc8cc902fa72 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -163,11 +163,8 @@ object, and make predictions with the resulting model to compute the training error. {% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. @@ -231,15 +228,13 @@ calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} -import java.util.Random; - import scala.Tuple2; import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.*; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; + import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; @@ -282,8 +277,8 @@ public class SVMClassifier { System.out.println("Area under ROC = " + auROC); // Save and load model - model.save(sc.sc(), "myModelPath"); - SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath"); + model.save(sc, "myModelPath"); + SVMModel sameModel = SVMModel.load(sc, "myModelPath"); } } {% endhighlight %} @@ -315,15 +310,12 @@ a dependency.
    -The following example shows how to load a sample dataset, build Logistic Regression model, +The following example shows how to load a sample dataset, build SVM model, and make predictions with the resulting model to compute the training error. -Note that the Python API does not yet support model save/load but will in the future. - {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithSGD +from pyspark.mllib.classification import SVMWithSGD, SVMModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model -model = LogisticRegressionWithSGD.train(parsedData) +model = SVMWithSGD.train(parsedData, iterations=100) # Evaluating the model on training data labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 56a2e9ca86bb1..bf6d124fd5d8d 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -14,9 +14,8 @@ and use it for prediction. MLlib supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -These models are typically used for [document classification] -(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). @@ -54,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial") +val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() diff --git a/docs/monitoring.md b/docs/monitoring.md index e75018499003a..bcf885fe4e681 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -228,6 +228,14 @@ for a running application, at `http://localhost:4040/api/v1`. /applications/[app-id]/storage/rdd/[rdd-id] Details for the storage status of a given RDD + + /applications/[app-id]/logs + Download the event logs for all attempts of the given application as a zip file + + + /applications/[app-id]/[attempt-id]/logs + Download the event logs for the specified attempt of the given application as a zip file + When running on Yarn, each application has multiple attempts, so `[app-id]` is actually diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 10f474f237bfa..d5ff416fe89a4 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -54,7 +54,7 @@ import org.apache.spark.SparkConf
    -Spark {{site.SPARK_VERSION}} works with Java 6 and higher. If you are using Java 8, Spark supports +Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java 8, Spark supports [lambda expressions](http://docs.oracle.com/javase/tutorial/java/javaOO/lambdaexpressions.html) for concisely writing functions, otherwise you can use the classes in the [org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 9d55f435e80ad..96cf612c54fdd 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -242,6 +242,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes running against earlier versions, this property will be ignored. + + spark.yarn.keytab + (none) + + The full path to the file that contains the keytab for the principal specified above. + This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. + + + + spark.yarn.principal + (none) + + Principal to be used to login to KDC, while running on secure HDFS. + + # Launching Spark on YARN diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 0eed9adacf123..12d7d6e159bea 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -77,7 +77,7 @@ Note, the master machine accesses each of the worker machines via ssh. By defaul If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. -Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: +Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`: - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. diff --git a/docs/sparkr.md b/docs/sparkr.md new file mode 100644 index 0000000000000..4d82129921a37 --- /dev/null +++ b/docs/sparkr.md @@ -0,0 +1,223 @@ +--- +layout: global +displayTitle: SparkR (R on Spark) +title: SparkR (R on Spark) +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. +In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that +supports operations like selection, filtering, aggregation etc. (similar to R data frames, +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. + +# SparkR DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +structured data files, tables in Hive, external databases, or existing local R data frames. + +All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. + +## Starting Up: SparkContext, SQLContext + +
    +The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. +You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name +etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the +SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should +already be created for you. + +{% highlight r %} +sc <- sparkR.init() +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} + +
    + +## Creating DataFrames +With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). + +### From local data frames +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. + +
    +{% highlight r %} +df <- createDataFrame(sqlContext, faithful) + +# Displays the content of the DataFrame to stdout +head(df) +## eruptions waiting +##1 3.600 79 +##2 1.800 54 +##3 3.333 74 + +{% endhighlight %} +
    + +### From Data Sources + +SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). + +We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +
    + +{% highlight r %} +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +head(people) +## age name +##1 NA Michael +##2 30 Andy +##3 19 Justin + +# SparkR automatically infers the schema from the JSON file +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +{% endhighlight %} +
    + +The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example +to a Parquet file using `write.df` + +
    +{% highlight r %} +write.df(people, path="people.parquet", source="parquet", mode="overwrite") +{% endhighlight %} +
    + +### From Hive tables + +You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). + +
    +{% highlight r %} +# sc is an existing SparkContext. +hiveContext <- sparkRHive.init(sc) + +sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- sql(hiveContext, "FROM src SELECT key, value") + +# results is now a DataFrame +head(results) +## key value +## 1 238 val_238 +## 2 86 val_86 +## 3 311 val_311 + +{% endhighlight %} +
    + +## DataFrame Operations + +SparkR DataFrames support a number of functions to do structured data processing. +Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: + +### Selecting rows, columns + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, faithful) + +# Get basic information about the DataFrame +df +## DataFrame[eruptions:double, waiting:double] + +# Select only the "eruptions" column +head(select(df, df$eruptions)) +## eruptions +##1 3.600 +##2 1.800 +##3 3.333 + +# You can also pass in column name as strings +head(select(df, "eruptions")) + +# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +head(filter(df, df$waiting < 50)) +## eruptions waiting +##1 1.750 47 +##2 1.750 47 +##3 1.867 48 + +{% endhighlight %} + +
    + +### Grouping, Aggregation + +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below + +
    +{% highlight r %} + +# We use the `n` operator to count the number of times each waiting time appears +head(summarize(groupBy(df, df$waiting), count = n(df$waiting))) +## waiting count +##1 81 13 +##2 60 6 +##3 68 1 + +# We can also sort the output from the aggregation to get the most common waiting times +waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting)) +head(arrange(waiting_counts, desc(waiting_counts$count))) + +## waiting count +##1 78 15 +##2 83 14 +##3 81 13 + +{% endhighlight %} +
    + +### Operating on Columns + +SparkR also provides a number of functions that can be directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can assign this to a new column in the same DataFrame +df$waiting_secs <- df$waiting * 60 +head(df) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 + +{% endhighlight %} +
    + +## Running SQL Queries from SparkR +A SparkR DataFrame can also be registered as a temporary table in Spark SQL and registering a DataFrame as a table allows you to run SQL queries over its data. +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +{% highlight r %} +# Load a JSON file +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5b41c0ee6e430..40e33f757d693 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,6 +11,7 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. # DataFrames @@ -108,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Displays the content of the DataFrame to stdout df.show() @@ -121,7 +122,7 @@ df.show() JavaSparkContext sc = ...; // An existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Displays the content of the DataFrame to stdout df.show(); @@ -134,7 +135,7 @@ df.show(); from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Displays the content of the DataFrame to stdout df.show() @@ -170,7 +171,7 @@ val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Show the content of the DataFrame df.show() @@ -220,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Show the content of the DataFrame df.show(); @@ -276,7 +277,7 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) # Create the DataFrame -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Show the content of the DataFrame df.show() @@ -776,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +val df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %}
    @@ -786,8 +787,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); -df.select("name", "favorite_color").save("namesAndFavColors.parquet"); +DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); {% endhighlight %} @@ -797,8 +798,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %} @@ -826,8 +827,8 @@ using this syntax.
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") +df.select("name", "age").write.format("parquet").save("namesAndAges.parquet") {% endhighlight %}
    @@ -836,8 +837,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); -df.select("name", "age").save("namesAndAges.parquet", "parquet"); +DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); +df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); {% endhighlight %} @@ -847,8 +848,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") +df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") {% endhighlight %} @@ -906,7 +907,7 @@ new data. Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. @@ -946,11 +947,11 @@ import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. -people.saveAsParquetFile("people.parquet") +people.write.parquet("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. -val parquetFile = sqlContext.parquetFile("people.parquet") +val parquetFile = sqlContext.read.parquet("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile") @@ -968,11 +969,11 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet"); +schemaPeople.write().parquet("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. -DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); +DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -994,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function() schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet") +schemaPeople.read.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.parquetFile("people.parquet") +parquetFile = sqlContext.write.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. 
parquetFile.registerTempTable("parquetFile"); @@ -1030,7 +1031,7 @@ teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND a teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) for (teenName in collect(teenNames)) { cat(teenName, "\n") -} +} {% endhighlight %}
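    The save-mode behavior described above can also be spelled out explicitly on the new writer interface. The following is a sketch (not part of the original guide) that reuses the `df` from the earlier load/save examples; the output path is arbitrary:
    {% highlight scala %}
    import org.apache.spark.sql.SaveMode

    // Explicitly choose what happens when the target already exists.
    // SaveMode.Ignore behaves like CREATE TABLE IF NOT EXISTS: existing data is left untouched.
    df.select("name", "age")
      .write
      .format("parquet")
      .mode(SaveMode.Ignore)
      .save("namesAndAges.parquet")

    // The string form is equivalent: "append", "overwrite", "ignore", or "error".
    df.write.mode("overwrite").parquet("namesAndAges.parquet")
    {% endhighlight %}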
    @@ -1086,9 +1087,9 @@ path {% endhighlight %} -By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will -automatically extract the partitioning information from the paths. Now the schema of the returned -DataFrame becomes: +By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +will automatically extract the partitioning information from the paths. +Now the schema of the returned DataFrame becomes: {% highlight text %} @@ -1101,7 +1102,11 @@ root {% endhighlight %} Notice that the data types of the partitioning columns are automatically inferred. Currently, -numeric data types and string type are supported. +numeric data types and string type are supported. Sometimes users may not want to automatically +infer the data types of the partitioning columns. For these use cases, the automatic type inference +can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which defaults to +`true`. When type inference is disabled, string type will be used for the partitioning columns. + ### Schema merging @@ -1121,15 +1126,15 @@ import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") -df1.saveAsParquetFile("data/test_table/key=1") +df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") -df2.saveAsParquetFile("data/test_table/key=2") +df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.parquetFile("data/test_table") +val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1268,12 +1273,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
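    The partition-column type inference mentioned above can be switched off through `setConf`. A small sketch (it assumes an existing `sqlContext` and the `data/test_table` layout from the schema-merging example):
    {% highlight scala %}
    // Disable automatic type inference for partition columns; they will be read back as strings.
    sqlContext.setConf("spark.sql.sources.partitionColumnTypeInference.enabled", "false")

    val df = sqlContext.read.parquet("data/test_table")
    df.printSchema()  // the partition column "key" now appears as string instead of int
    {% endhighlight %}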
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1284,8 +1287,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a DataFrame from the file(s) pointed to by path -val people = sqlContext.jsonFile(path) +val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1303,19 +1305,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +val anotherPeople = sqlContext.read.json(anotherPeopleRDD) {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext` : +This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +or a JSON file. -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. - -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1325,9 +1325,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. -String path = "examples/src/main/resources/people.json"; -// Create a DataFrame from the file(s) pointed to by path -DataFrame people = sqlContext.jsonFile(path); +DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -1346,18 +1344,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json` on a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1368,9 +1363,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. -path = "examples/src/main/resources/people.json" -# Create a DataFrame from the file(s) pointed to by path -people = sqlContext.jsonFile(path) +people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1393,12 +1386,11 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame using +the `jsonFile` function, which loads data from a directory of JSON files where each line of the +files is a JSON object. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1502,7 +1494,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. +adds support for finding tables in the MetaStore and writing queries using HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext @@ -1526,8 +1518,8 @@ adds support for finding tables in the MetaStore and writing queries using HiveQ # sc is an existing SparkContext. sqlContext <- sparkRHive.init(sc) -hql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. results = sqlContext.sql("FROM src SELECT key, value").collect() @@ -1537,6 +1529,70 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
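    For reference, the Scala equivalent of the Python and R snippets above is roughly the following sketch (it assumes an existing `sc` and the same `src` table):
    {% highlight scala %}
    // sc is an existing SparkContext.
    val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)

    hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")

    // Queries are expressed in HiveQL.
    val results = hiveContext.sql("FROM src SELECT key, value").collect()
    {% endhighlight %}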
    +### Interacting with Different Versions of Hive Metastore + +One of the most important pieces of Spark SQL's Hive support is interaction with the Hive metastore, which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. + +Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, and the other dedicated to communicating with the Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts.
    +
    +<table class="table">
    +  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
    +  <tr>
    +    <td><code>spark.sql.hive.metastore.version</code></td>
    +    <td><code>0.13.1</code></td>
    +    <td>
    +      Version of the Hive metastore. Available options are <code>0.12.0</code> and
    +      <code>0.13.1</code>. Support for more versions is coming in the future.
    +    </td>
    +  </tr>
    +  <tr>
    +    <td><code>spark.sql.hive.metastore.jars</code></td>
    +    <td><code>builtin</code></td>
    +    <td>
    +      Location of the jars that should be used to instantiate the HiveMetastoreClient. This
    +      property can be one of three options:
    +      <ol>
    +        <li><code>builtin</code><br/>
    +        Use Hive 0.13.1, which is bundled with the Spark assembly jar when <code>-Phive</code> is
    +        enabled. When this option is chosen, <code>spark.sql.hive.metastore.version</code> must be
    +        either <code>0.13.1</code> or not defined.</li>
    +        <li><code>maven</code><br/>
    +        Use Hive jars of specified version downloaded from Maven repositories.</li>
    +        <li>A classpath in the standard format for both Hive and Hadoop.</li>
    +      </ol>
    +    </td>
    +  </tr>
    +  <tr>
    +    <td><code>spark.sql.hive.metastore.sharedPrefixes</code></td>
    +    <td><code>com.mysql.jdbc,<br/>org.postgresql,<br/>com.microsoft.sqlserver,<br/>oracle.jdbc</code></td>
    +    <td>
    +      A comma separated list of class prefixes that should be loaded using the classloader that is
    +      shared between Spark SQL and a specific version of Hive. An example of classes that should
    +      be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need
    +      to be shared are those that interact with classes that are already shared. For example,
    +      custom appenders that are used by log4j.
    +    </td>
    +  </tr>
    +  <tr>
    +    <td><code>spark.sql.hive.metastore.barrierPrefixes</code></td>
    +    <td><code>(empty)</code></td>
    +    <td>
    +      A comma separated list of class prefixes that should explicitly be reloaded for each version
    +      of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a
    +      prefix that typically would be shared (i.e. <code>org.apache.spark.*</code>).
    +    </td>
    +  </tr>
    +</table>
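    These metastore properties are normally set in `spark-defaults.conf` or via `--conf` on `spark-submit`; the sketch below shows one way to set them programmatically. The version and jar-location values are placeholders, not recommendations:
    {% highlight scala %}
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    val conf = new SparkConf()
      .setAppName("HiveMetastoreVersionExample")
      // Talk to a 0.12.0 metastore using Hive jars resolved from Maven (placeholder choices).
      .set("spark.sql.hive.metastore.version", "0.12.0")
      .set("spark.sql.hive.metastore.jars", "maven")

    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)
    {% endhighlight %}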
    + + ## JDBC To Other Databases Spark SQL also includes a data source that can read data from other databases using JDBC. This @@ -1570,7 +1626,7 @@ the Data Sources API. The following options are supported: dbtable - The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses. @@ -1714,7 +1770,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command - `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. + ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. @@ -1733,11 +1789,20 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. + + spark.sql.planner.externalSort + false + + When true, performs sorts spilling to disk as needed; otherwise, sorts each partition in memory. + + # Distributed SQL Engine -Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. +In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, +without the need to write any code. ## Running the Thrift JDBC/ODBC server @@ -1816,6 +1881,25 @@ options. ## Upgrading from Spark SQL 1.3 to 1.4 +#### DataFrame data reader/writer interface + +Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) +and writing data out (`DataFrame.write`), +and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). + +See the API docs for `SQLContext.read` ( + Scala, + Java, + Python +) and `DataFrame.write` ( + Scala, + Java, + Python +) for more information. + + +#### DataFrame.groupBy retains grouping columns + Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
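    As a quick illustration of the `groupBy` change just described, here is a sketch with a hypothetical DataFrame `df` that has `department` and `salary` columns:
    {% highlight scala %}
    import org.apache.spark.sql.functions._

    // In 1.4 the grouping column is kept, so the result has columns (department, MAX(salary)).
    val maxSalaries = df.groupBy("department").agg(max("salary"))

    // To fall back to the 1.3 behavior (aggregate columns only), disable retention:
    sqlContext.setConf("spark.sql.retainGroupColumns", "false")
    {% endhighlight %}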
    @@ -1939,7 +2023,7 @@ sqlContext.udf.register("strLen", (s: String) => s.length())
    {% highlight java %} -sqlContext.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType); {% endhighlight %}
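    Once registered, the UDF can be used from SQL in either language. A short Scala sketch (it assumes a temporary table `people` with a string column `name`):
    {% highlight scala %}
    sqlContext.udf.register("strLen", (s: String) => s.length)

    // The registered UDF is available in SQL statements.
    val nameLengths = sqlContext.sql("SELECT name, strLen(name) FROM people")
    nameLengths.show()
    {% endhighlight %}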
    diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 64714f0b799fc..998c8c994e4b4 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -7,7 +7,7 @@ title: Spark Streaming + Kafka Integration Guide ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming process the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability)). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g. HDFS), so that all the data can be recovered on failure. See the [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
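    Filling in the placeholders above with concrete (made-up) values, a receiver-based stream looks roughly like this:
    {% highlight scala %}
    import org.apache.spark.streaming.kafka.KafkaUtils

    // ZooKeeper quorum, consumer group id, and topic -> partition counts are placeholders.
    val topicMap = Map("pageviews" -> 1, "clicks" -> 1)
    val kafkaStream = KafkaUtils.createStream(
      streamingContext, "zk1:2181,zk2:2181", "my-consumer-group", topicMap)
    {% endhighlight %}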
    import org.apache.spark.streaming.kafka.*; @@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
    @@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application. streamingContext, [map of Kafka parameters], [set of topics to consume]) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
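    With concrete (again made-up) broker and topic values, the direct approach looks roughly like this:
    {% highlight scala %}
    import kafka.serializer.StringDecoder
    import org.apache.spark.streaming.kafka.KafkaUtils

    // Broker list and topic names are placeholders.
    val kafkaParams = Map("metadata.broker.list" -> "broker1:9092,broker2:9092")
    val topics = Set("pageviews", "clicks")
    val directKafkaStream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      streamingContext, kafkaParams, topics)
    {% endhighlight %}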
    import org.apache.spark.streaming.kafka.*; @@ -116,7 +116,7 @@ Next, we discuss how to use this approach in your streaming application. [map of Kafka parameters], [set of topics to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
    @@ -153,4 +153,4 @@ Next, we discuss how to use this approach in your streaming application. Another thing to note is that since this approach does not use Receivers, the standard receiver-related configurations (that is, the [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. -3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file +3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and then launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bd863d48d53e3..42b33947873b0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1946,10 +1946,10 @@ creates a single receiver (running on a worker machine) that receives a single s Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two -Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. These multiple -DStream can be unioned together to create a single DStream. Then the transformations that was -being applied on the single input DStream can applied on the unified stream. This is done as follows. +Kafka input streams, each receiving only one topic. This would run two receivers, +allowing data to be received in parallel, and increasing overall throughput. These multiple +DStreams can be unioned together to create a single DStream. Then the transformations that were +being applied on a single input DStream can be applied on the unified stream. This is done as follows.
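    The snippet the guide refers to is not part of this hunk; a sketch of the pattern, reusing the placeholder Kafka parameters from above, is:
    {% highlight scala %}
    // Create several input streams that each consume part of the source topic(s),
    // then union them into a single DStream before applying transformations.
    val numStreams = 2
    val kafkaStreams = (1 to numStreams).map { _ =>
      KafkaUtils.createStream(streamingContext, "zk1:2181", "my-consumer-group", Map("topic1" -> 1))
    }
    val unifiedStream = streamingContext.union(kafkaStreams)
    unifiedStream.print()
    {% endhighlight %}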
    diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index c6d5a1f0d0a81..84629cb9a0ca0 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,8 +19,9 @@ # limitations under the License. # -from __future__ import with_statement, print_function +from __future__ import division, print_function, with_statement +import codecs import hashlib import itertools import logging @@ -47,6 +48,8 @@ else: from urllib.request import urlopen, Request from urllib.error import HTTPError + raw_input = input + xrange = range SPARK_EC2_VERSION = "1.3.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -216,7 +219,8 @@ def parse_args(): "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: %default)") + help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + + "(Hadoop 2.4.0) (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -268,7 +272,8 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + + "is used as Hadoop major version (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + @@ -423,13 +428,14 @@ def get_spark_ami(opts): b=opts.spark_ec2_git_branch) ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) + reader = codecs.getreader("ascii") try: - ami = urlopen(ami_path).read().strip() - print("Spark AMI: " + ami) + ami = reader(urlopen(ami_path)).read().strip() except: print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) + print("Spark AMI: " + ami) return ami @@ -487,6 +493,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('udp', 2049, 2049, authorized_address) master_group.authorize('tcp', 4242, 4242, authorized_address) master_group.authorize('udp', 4242, 4242, authorized_address) + # RM in YARN mode uses 8088 + master_group.authorize('tcp', 8088, 8088, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -750,11 +758,15 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): 'mapreduce', 'spark-standalone', 'tachyon'] if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) + modules = list(filter(lambda x: x != "mapreduce", modules)) if opts.ganglia: modules.append('ganglia') + # Clear SPARK_WORKER_INSTANCES if running on YARN + if opts.hadoop_major_version == "yarn": + opts.worker_instances = "" + # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( @@ -992,6 +1004,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] + worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" template_vars = { "master_list": 
'\n'.join(master_addresses), "active_master": active_master, @@ -1005,7 +1018,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "spark_version": spark_v, "tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": "%d" % opts.worker_instances, + "spark_worker_instances": worker_instances_str, "spark_master_opts": opts.master_opts } @@ -1160,7 +1173,7 @@ def get_zones(conn, opts): # Gets the number of items in a partition def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions + num_slaves_this_zone = total // num_partitions if (total % num_partitions) - current_partitions > 0: num_slaves_this_zone += 1 return num_slaves_this_zone diff --git a/examples/pom.xml b/examples/pom.xml index 5b04b4f8d6ca0..e6884b09dca94 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -97,6 +97,11 @@ + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + org.apache.hbase hbase-testing-util @@ -392,45 +397,6 @@ - - - scala-2.10 - - !scala-2.11 - - - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-scala-sources - generate-sources - - add-source - - - - src/main/scala - scala-2.10/src/main/scala - scala-2.10/src/main/java - - - - - - - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 29158d5c85651..dac649d1d5ae6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -97,7 +97,7 @@ public static void main(String[] args) { DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. 
DataFrame results = model2.transform(test); diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py new file mode 100644 index 0000000000000..f0ca97c724940 --- /dev/null +++ b/examples/src/main/python/ml/cross_validator.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.evaluation import BinaryClassificationEvaluator +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="CrossValidatorExample") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. 
+ tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + # This will allow us to jointly choose parameters for all Pipeline stages. + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + paramGrid = ParamGridBuilder() \ + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .build() + + crossval = CrossValidator(estimator=pipeline, + estimatorParamMaps=paramGrid, + evaluator=BinaryClassificationEvaluator(), + numFolds=2) # use 3+ folds in practice + + # Run cross-validation, and choose the best set of parameters. + cvModel = crossval.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents. cvModel uses the best model found (lrModel). + prediction = cvModel.transform(test) + selected = prediction.select("id", "text", "probability", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py new file mode 100644 index 0000000000000..6446f0fe5eeab --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import GBTRegressor +from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. +Note: GBTClassifier only supports binary classification currently +Run with: + bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py +""" + + +def testClassification(train, test): + # Train a GradientBoostedTrees model. 
+ + rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = BinaryClassificationMetrics(predictionAndLabels) + print("AUC %.3f" % metrics.areaUnderROC) + + +def testRegression(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGBTExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py new file mode 100644 index 0000000000000..c7730e1bfacd9 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import RandomForestRegressor +from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a RandomForest Classification/Regression Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/random_forest_example.py +""" + + +def testClassification(train, test): + # Train a RandomForest model. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + # Note: Use larger numTrees in practice. 
+ + rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + +def testRegression(train, test): + # Train a RandomForest model. + # Note: Use larger numTrees in practice. + + rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: random_forest_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonRandomForestExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py new file mode 100644 index 0000000000000..a9f29dab2d602 --- /dev/null +++ b/examples/src/main/python/ml/simple_params_example.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import pprint +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql import SQLContext + +""" +A simple example demonstrating ways to specify parameters for Estimators and Transformers. +Run with: + bin/spark-submit examples/src/main/python/ml/simple_params_example.py +""" + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: simple_params_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonSimpleParamsExample") + sqlContext = SQLContext(sc) + + # prepare training data. + # We create an RDD of LabeledPoints and convert them into a DataFrame. 
+ # A LabeledPoint is an Object with two fields named label and features + # and Spark SQL identifies these fields and creates the schema appropriately. + training = sc.parallelize([ + LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), + LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), + LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), + LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + + # Create a LogisticRegression instance with maxIter = 10. + # This instance is an Estimator. + lr = LogisticRegression(maxIter=10) + # Print out the parameters, documentation, and any default values. + print("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + # We may also set parameters using setter methods. + lr.setRegParam(0.01) + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a Transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print("Model 1 was fit using parameters:\n") + pprint.pprint(model1.extractParamMap()) + + # We may alternatively specify parameters using a parameter map. + # paramMap overrides all lr parameters set earlier. + paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + + # Now learn a new model using the new parameters. + model2 = lr.fit(training, paramMap) + print("Model 2 was fit using parameters:\n") + pprint.pprint(model2.extractParamMap()) + + # prepare test data. + test = sc.parallelize([ + LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), + LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), + LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegressionModel.transform will only use the 'features' column. + # Note that model2.transform() outputs a 'myProbability' column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. 
+ result = model2.transform(test) \ + .select("features", "label", "myProbability", "prediction") \ + .collect() + + for row in result: + print("features=%s,label=%s -> prob=%s, prediction=%s" + % (row.features, row.label, row.myProbability, row.prediction)) + + sc.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 96ddac761d698..e1fd85b082c08 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -51,7 +51,7 @@ parquet_rdd = sc.newAPIHadoopFile( path, - 'parquet.avro.AvroParquetInputFormat', + 'org.apache.parquet.avro.AvroParquetInputFormat', 'java.lang.Void', 'org.apache.avro.generic.IndexedRecord', valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 11d5c92c5952d..023bb3ee2d108 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -104,8 +104,8 @@ object CassandraCQLTest { val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[CqlPagingInputFormat], - classOf[java.util.Map[String,ByteBuffer]], - classOf[java.util.Map[String,ByteBuffer]]) + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.Map[String, ByteBuffer]]) println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { @@ -118,7 +118,7 @@ object CassandraCQLTest { case (productId, saleCount) => println(productId + ":" + saleCount) } - val casoutputCF = aggregatedRDD.map { + val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) val outKey: java.util.Map[String, ByteBuffer] = outColFamKey diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index a55e0dc8d36c2..c3fc74a116c0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -39,7 +39,7 @@ object LocalLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 32e02eab8b031..75c82117cbad2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. 
- * + * * Usage: LogQuery [logFile] */ object LogQuery { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 6c0ac8013ce34..30c4261551837 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -117,7 +117,7 @@ object SparkALS { var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users - val Rc = sc.broadcast(R) + val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 8c01a60844620..1e6b4fb0c7514 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -44,7 +44,7 @@ object SparkLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 8d092b6506d33..bd7894f184c4c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -51,7 +51,7 @@ object SparkPageRank { showWarning() val sparkConf = new SparkConf().setAppName("PageRank") - val iters = if (args.length > 0) args(1).toInt else 10 + val iters = if (args.length > 1) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index ab6e63deb3c95..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.bagel._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) { - self.outEdges.map(targetId => new PRMessage(targetId, newValue / self.outEdges.size)) - } else { - Array[PRMessage]() - } - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double) - (self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int) - : (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)" - .format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions: Int = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } - - override def hashCode: Int = numPartitions -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 859abedf2a55e..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ - -import org.apache.spark.bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println( - "Usage: WikipediaPageRank ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.setAppName("WikipediaPageRank") - sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage])) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRank") - val sc = new SparkContext(sparkConf) - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) { - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache() - } else { - vertices = vertices.cache() - } - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect().mkString) - println(top) - - sc.stop() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index 576a3e371b993..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.xml.{XML, NodeSeq} - -import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import scala.reflect.ClassTag - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRankStandalone " + - " ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRankStandalone") - - val sc = new SparkContext(sparkConf) - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) { - input.map(parseArticle _).partitionBy(partitioner).cache() - } else { - input.map(parseArticle _).cache() - } - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, - sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - sc.stop() - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to 
numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapperIterable, rankWrapperIterable)) => - val linksWrapper = linksWrapperIterable.iterator - val rankWrapper = rankWrapperIterable.iterator - if (linksWrapper.hasNext) { - val linksWrapperHead = linksWrapper.next - if (rankWrapper.hasNext) { - val rankWrapperHead = rankWrapper.next - linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size)) - } else { - linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends org.apache.spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T: ClassTag](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T: ClassTag](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala new file mode 100644 index 0000000000000..b54466fd48bc5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
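For context on the two Bagel examples deleted above: the core update in the removed WikipediaPageRankStandalone.pageRank loop is the standard PageRank iteration, rank = a/n + (1 - a) * sum(contribs) with a = 0.15. A minimal sketch of the same loop written with plain RDD joins (not part of this patch; it drops the dangling-node handling of the original for brevity):

{{{
import org.apache.spark.rdd.RDD

// links: page id -> outgoing link targets, as produced by parseArticle in the deleted file.
def pageRankSketch(links: RDD[(String, Array[String])], numIterations: Int): RDD[(String, Double)] = {
  val n = links.count()
  val a = 0.15                              // teleport probability, as in the deleted example
  var ranks = links.mapValues(_ => 1.0 / n) // defaultRank
  for (_ <- 1 to numIterations) {
    val contribs = links.join(ranks).values.flatMap { case (targets, rank) =>
      targets.map(dest => (dest, rank / targets.size))
    }
    // rank_i = a / n + (1 - a) * (sum of contributions received by page i)
    ranks = contribs.reduceByKey(_ + _).mapValues(sum => a / n + (1 - a) * sum)
  }
  ranks
}
}}}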
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.sql.DataFrame + +/** + * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LinearRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \ + * data/mllib/sample_linear_regression_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LinearRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegressionExample") { + head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LinearRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "regression", params.fracTest) + + val lir = new LinearRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + // Train the model + val startTime = System.nanoTime() + val lirModel = lir.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Print the weights and intercept for linear regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala new file mode 100644 index 0000000000000..b12f833ce94c8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.DataFrame + +/** + * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LogisticRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \ + * data/mllib/sample_libsvm_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LogisticRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + fitIntercept: Boolean = true, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LogisticRegressionExample") { + head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Boolean]("fitIntercept") + .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LogisticRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "classification", params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol("indexedLabel") + stages += labelIndexer + + val lor = new LogisticRegression() + .setFeaturesCol("features") + .setLabelCol("indexedLabel") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + stages += lor + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + val lirModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + // Print the weights and intercept for logistic regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index b99d0a1246011..6927eb8f275cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -73,7 +73,7 @@ object OneVsRestExample { .action((x, c) => c.copy(fracTest = x)) opt[String]("testInput") .text("input path to test dataset. If given, option fracTest is ignored") - .action((x,c) => c.copy(testInput = Some(x))) + .action((x, c) => c.copy(testInput = Some(x))) opt[Int]("maxIter") .text(s"maximum number of iterations for Logistic Regression." 
+ s" default: ${defaultParams.maxIter}") @@ -88,10 +88,10 @@ object OneVsRestExample { .action((x, c) => c.copy(fitIntercept = x)) opt[Double]("regParam") .text(s"the regularization parameter for Logistic Regression.") - .action((x,c) => c.copy(regParam = Some(x))) + .action((x, c) => c.copy(regParam = Some(x))) opt[Double]("elasticNetParam") .text(s"the ElasticNet mixing parameter for Logistic Regression.") - .action((x,c) => c.copy(elasticNetParam = Some(x))) + .action((x, c) => c.copy(elasticNetParam = Some(x))) checkConfig { params => if (params.fracTest < 0 || params.fracTest >= 1) { failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a991f50e338..a0561e2573fc9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -87,7 +87,7 @@ object SimpleParamsExample { LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. model2.transform(test.toDF()) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b0613632c9946..3381941673db8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,7 +22,6 @@ import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -354,7 +353,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. + * + * This is just for demo purpose. In general, don't copy this code because it is NOT efficient + * due to the use of structural types, which leads to one reflection call per record. 
*/ + // scalastyle:off structural.type private[mllib] def meanSquaredError( model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { @@ -363,4 +366,5 @@ object DecisionTreeRunner { err * err }.mean() } + // scalastyle:on structural.type } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index df76b45e50810..f8c71ccabc43b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -40,23 +40,23 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) - + val ctx = new SparkContext(conf) + val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) .run(data) - + for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } - + println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index a11890d6f2b1c..3ebb112fc069e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -36,22 +36,21 @@ object AvroConversionUtil extends Serializable { return null } schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException(s"Unknown Avro schema type ${other.getName}") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 92867b44be138..016de4c63d1d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -104,10 +104,8 @@ extends Actor with ActorHelper { object FeederActor { def main(args: Array[String]) { - if(args.length < 
2){ - System.err.println( - "Usage: FeederActor \n" - ) + if (args.length < 2){ + System.err.println("Usage: FeederActor \n") System.exit(1) } val Seq(host, port) = args.toSeq diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala similarity index 97% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 11a8cf09533ce..fbe394de4a179 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -51,7 +51,7 @@ object DirectKafkaWordCount { // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 95% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index f407367a54f6c..60416ee343544 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -49,10 +49,10 @@ object KafkaWordCount { val Array(zkQuorum, group, topics, numThreads) = args val sparkConf = new SparkConf().setAppName("KafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") - val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap + val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1L)) @@ -96,7 +96,7 @@ object KafkaWordCountProducer { producer.send(message) } - Thread.sleep(100) + Thread.sleep(1000) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 85b9a54b40baf..813c8554f5193 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -40,7 +40,7 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq - + var client: MqttClient = null try { @@ -49,7 +49,7 @@ object MQTTPublisher { client.connect() - val msgtopic = client.getTopic(topic) + val msgtopic = client.getTopic(topic) val msgContent = "hello mqtt demo for spark streaming" val message = new MqttMessage(msgContent.getBytes("utf-8")) @@ -59,10 +59,10 @@ object MQTTPublisher { println(s"Published data. 
topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) + Thread.sleep(10) println("Queue is full, wait for to consume data from the message queue") - } - } + } + } } catch { case e: MqttException => println("Exception Caught: " + e) } finally { @@ -107,7 +107,7 @@ object MQTTWordCount { val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 54d996b8ac990..889f052c70263 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -57,8 +57,7 @@ object PageViewGenerator { 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - + val userID = Map((1 to 100).map(_ -> .01) : _*) def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { val rand = new Random().nextDouble() diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a24..7a7dccc3d0922 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -42,15 +42,46 @@ org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index fd01807fc3ac4..dc2a4ab138e18 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.flume.Channel import org.apache.commons.lang3.RandomStringUtils @@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. 
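On the PageViewGenerator hunk above: the change is whitespace-only around ": _*", but a short note on the idiom for readers of the diff context:

{{{
// `: _*` expands the Seq of pairs into the varargs that Map.apply expects,
// i.e. Map(1 -> .01, 2 -> .01, ..., 100 -> .01); the 100 weights sum to 1.0,
// which pickFromDistribution treats as a probability distribution.
val userID: Map[Int, Double] = Map((1 to 100).map(_ -> .01) : _*)
}}}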
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala rename to external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala index d75959f480756..845fc8debda75 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -14,11 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.streaming.flume.sink -package org.apache.spark.util.collection +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong -private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] { - def hasNext: Boolean = iter.hasNext +/** + * Thread factory that generates daemon threads with a specified name format. + */ +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t + } - def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V]) } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index ea45b14294df9..7ad43b1d7b0a0 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -143,7 +143,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg(msg) } else { // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("",seqNum, events) + eventBatch = new EventBatch("", seqNum, events) } }) } catch { diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c142..fa43629d49771 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -24,16 +24,24 @@ import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +// Due to MNG-1378, there is not a way to include test dependencies transitively. +// We cannot include Spark core tests as a dependency here because it depends on +// Spark core main, which has too many dependencies to require here manually. 
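The new SparkSinkThreadFactory above replaces the Guava ThreadFactoryBuilder usage removed from SparkAvroCallbackHandler (and from SparkSinkSuite below), which appears to be what lets the flume-sink pom exclude guava from its compile classpath. Since the factory is private[sink], a usage sketch only makes sense inside that package; this mirrors the call sites in the patch rather than adding anything new:

{{{
package org.apache.spark.streaming.flume.sink

import java.util.concurrent.Executors

object SparkSinkThreadFactorySketch {
  def main(args: Array[String]): Unit = {
    // Produces daemon threads named "Spark Sink Processor Thread - 1", "- 2", ...
    val pool = Executors.newFixedThreadPool(
      2, new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))
    pool.submit(new Runnable {
      override def run(): Unit = println(Thread.currentThread().getName)
    })
    pool.shutdown()
  }
}
}}}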
+// For this reason, we continue to use FunSuite and ignore the scalastyle checks +// that fail if this is detected. +//scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { +//scalastyle:on + val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 8df7edbdcad33..14f7daaf417e0 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming-flume-sink_${scala.binary.version} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index dc629df4f4ac2..65c49c131518b 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k,v) <- headers) { + for ((k, v) <- headers) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 60e2994431b38..1e32a365a1eee 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -152,9 +152,9 @@ class FlumeReceiver( val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()) val channelPipelineFactory = new CompressionChannelPipelineFactory() - + new NettyServer( - responder, + responder, new InetSocketAddress(host, port), channelFactory, channelPipelineFactory, @@ -188,12 +188,12 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. 
* * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel + * a successful response to indicate it can remove the event/batch + * from the configured channel */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 92fa5b41be89e..583e7dca317ad 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver( } /** - * A wrapper around the transceiver and the Avro IPC API. + * A wrapper around the transceiver and the Avro IPC API. * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 93afe50c2134f..d772b9ca9b570 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -31,16 +31,16 @@ import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} -class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchCount = 5 val eventsPerBatch = 100 diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 39e6754c81dbf..c926359987d89 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -35,15 +35,15 @@ import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} import org.apache.spark.util.Utils -class FlumeStreamSuite extends FunSuite with 
BeforeAndAfter with Matchers with Logging { +class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") var ssc: StreamingContext = null @@ -138,7 +138,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val status = client.appendBatch(inputEvents.toList) status should be (avro.Status.OK) } - + eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 0b79f47647f6b..8059c443827ef 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 243ce6eaca658..ded863bd985e8 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.kafka kafka_${scala.binary.version} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 6cf254a7b69cb..65d51d87f8486 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { r.flatMap { tm: TopicMetadata => tm.partitionsMetadata.map { pm: PartitionMetadata => TopicAndPartition(tm.topic, pm.partitionId) - } + } } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 6dc4e9517d5a4..b608b75952721 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging { val props = new Properties() props.put("metadata.broker.list", brokerAddress) props.put("serializer.class", classOf[StringEncoder].getName) + // wait for all in-sync replicas to ack sends + props.put("request.required.acks", "-1") props } @@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging { tryAgain(1) } - /** Wait until the leader offset for the given topic/partition equals the specified offset */ - def waitUntilLeaderOffset( - topic: String, - partition: Int, - offset: Long): Unit = { - eventually(Time(10000), Time(100)) { - val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress)) - val tp = TopicAndPartition(topic, partition) - val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset - assert( - llo == offset, - s"$topic $partition $offset not reached after timeout") - } - } - private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => diff --git 
a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 8be2707528d93..0b8a391a2c569 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -315,7 +315,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -363,7 +363,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -427,7 +427,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -489,7 +489,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). 
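The KafkaUtils hunks above are whitespace-only, but since each of them points readers at HasOffsetRanges for offset tracking, here is a minimal sketch of that documented pattern (ssc, kafkaParams and topicsSet are assumed to exist; the variable names are illustrative):

{{{
import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils, OffsetRange}

val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, topicsSet)

stream.foreachRDD { rdd =>
  // The RDDs produced by the direct stream carry their Kafka offset ranges;
  // cast to HasOffsetRanges to read them (e.g. to update ZooKeeper yourself).
  val offsets: Array[OffsetRange] = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
  offsets.foreach(o => println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}"))
}
}}}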
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index 5cf379635354f..a9dc6e50613ca 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -72,9 +72,6 @@ public void testKafkaRDD() throws InterruptedException { HashMap kafkaParams = new HashMap(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); - kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length); - kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length); - OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), OffsetRange.create(topic2, 0, 0, 1) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index b6d314dfc7783..47bbfb605850a 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -28,10 +28,10 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.Utils class DirectKafkaStreamSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll with Eventually diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index 7fb841b79cb65..d66830cbacdee 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { +import org.apache.spark.SparkFunSuite + +class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) private var kc: KafkaCluster = null diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 39c3fb448ff57..d5baf5fd89994 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,11 +22,11 @@ import scala.util.Random import 
kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.spark._ -class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ @@ -61,11 +61,9 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt}") - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size) - val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) - val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet @@ -86,7 +84,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { // this is the "lots of messages" case kafkaTestUtils.sendMessages(topic, sent) val sentCount = sent.values.sum - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount) // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) @@ -113,7 +110,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val sentOnlyOne = Map("d" -> 1) kafkaTestUtils.sendMessages(topic, sentOnlyOne) - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1) assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 24699dfc33adb..8ee2cc660f849 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -23,14 +23,14 @@ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { +class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 38548dd73b82c..80e2df62de3fe 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -26,15 +26,15 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import 
org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends FunSuite +class ReliableKafkaStreamSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with Eventually { private val sparkConf = new SparkConf() diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 98f95a9a64fa0..0e41e5781784b 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.eclipse.paho org.eclipse.paho.client.mqttv3 diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 40f5f18547236..7c2f18cb35bda 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,22 +17,10 @@ package org.apache.spark.streaming.mqtt -import java.io.IOException -import java.util.concurrent.Executors -import java.util.Properties - -import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.MqttTopic import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.storage.StorageLevel @@ -87,7 +75,7 @@ class MQTTReceiver( // Handles Mqtt message override def messageArrived(topic: String, message: MqttMessage) { - store(new String(message.getPayload(),"utf-8")) + store(new String(message.getPayload(), "utf-8")) } override def deliveryComplete(token: IMqttDeliveryToken) { diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index a19a72c58a705..c4bf5aa7869bb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} @@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class MQTTStreamSuite extends 
FunSuite with Eventually with BeforeAndAfter { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 8b6a8959ac4cf..178ae8de13b57 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.twitter4j twitter4j-stream diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 9ee57d7581d85..d9acb568879fe 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.streaming.twitter -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a50d378b34335..37bfd10d43663 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + ${akka.group} akka-zeromq_${scala.binary.version} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index a7566e733d891..35d2e62c68480 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends FunSuite { +class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4351a8a12fe21..f138251748c9e 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 
25847a1b33d9c..c6f60bc907438 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -40,6 +40,13 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index df77f4be9db1d..be8b62d3cc6ba 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -119,7 +119,7 @@ object KinesisWordCountASL extends Logging { val batchInterval = Milliseconds(2000) // Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information - // on sequence number of records that have been received. Same as batchInterval for this + // on sequence number of records that have been received. Same as batchInterval for this // example. val kinesisCheckpointInterval = batchInterval @@ -145,7 +145,7 @@ object KinesisWordCountASL extends Logging { // Map each word to a (word, 1) tuple so we can reduce by key to count the words val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) - + // Print the first 10 wordCounts wordCounts.print() @@ -208,16 +208,16 @@ object KinesisWordProducerASL { recordsPerSecond: Int, wordsPerRecord: Int): Seq[(String, Int)] = { - val randomWords = List("spark","you","are","my","father") + val randomWords = List("spark", "you", "are", "my", "father") val totals = scala.collection.mutable.Map[String, Int]() - + // Create the low-level Kinesis Client from the AWS Java SDK. val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) kinesisClient.setEndpoint(endpoint) println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + s" $recordsPerSecond records per second and $wordsPerRecord words per record") - + // Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord for (i <- 1 to 10) { // Generate recordsPerSec records to put onto the stream @@ -255,8 +255,8 @@ object KinesisWordProducerASL { } } -/** - * Utility functions for Spark Streaming examples. +/** + * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ private[streaming] object StreamingExamples extends Logging { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 1c9b0c218ae18..83a4537559512 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. * - * @param checkpointInterval + * @param checkpointInterval * @param currentClock. 
Default to current SystemClock if none is passed in (mocking purposes) */ private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, + checkpointInterval: Duration, currentClock: Clock = new SystemClock()) extends Logging { - + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** - * Check if it's time to checkpoint based on the current time and the derived time + * Check if it's time to checkpoint based on the current time and the derived time * for the next checkpoint * * @return true if it's time to checkpoint diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 7dd8bfdc2a6db..1a8a4cecc1141 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -44,12 +44,12 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * https://github.com/awslabs/amazon-kinesis-client * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers to run within a + * Instances of this class will get shipped to the Spark Streaming Workers to run within a * Spark Executor. * * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing + * the KCL will throw errors. This usually requires deleting the backing * DynamoDB table with the same name this Kinesis application. * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -87,7 +87,7 @@ private[kinesis] class KinesisReceiver( */ /** - * workerId is used by the KCL should be based on the ip address of the actual Spark Worker + * workerId is used by the KCL should be based on the ip address of the actual Spark Worker * where this code runs (not the driver's IP address.) */ private var workerId: String = null @@ -121,7 +121,7 @@ private[kinesis] class KinesisReceiver( /* * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the * IRecordProcessor.processRecords() method. * We're using our custom KinesisRecordProcessor in this case. */ diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index f65e743c4e2a3..fe9e3a0c793e2 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -35,9 +35,9 @@ import com.amazonaws.services.kinesis.model.Record /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. 
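The KinesisCheckpointState hunks above describe the pattern this class implements: arm a clock at currentClock + checkpointInterval and only checkpoint once that deadline has passed. As a minimal, standalone sketch of that idea (the CheckpointTimer name and millisecond-based clock below are illustrative only, not Spark APIs):

// Decide when a checkpoint is due: set deadline = now + interval,
// then re-arm the deadline after each successful checkpoint.
class CheckpointTimer(intervalMs: Long, now: () => Long = () => System.currentTimeMillis()) {
  private var deadline: Long = now() + intervalMs

  // True once the current time has reached the armed deadline.
  def shouldCheckpoint(): Boolean = now() >= deadline

  // Re-arm after checkpointing.
  def advance(): Unit = { deadline = now() + intervalMs }
}

// Usage sketch:
//   val timer = new CheckpointTimer(intervalMs = 2000L)
//   if (timer.shouldCheckpoint()) { /* checkpointer.checkpoint() */ timer.advance() }
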
* This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create * multiple Receivers. * * @param receiver Kinesis receiver @@ -69,14 +69,14 @@ private[kinesis] class KinesisRecordProcessor( * and Spark Streaming's Receiver.store(). * * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored + * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { /* - * Notes: + * Notes: * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the * internally-configured Spark serializer (kryo, etc). @@ -84,19 +84,19 @@ private[kinesis] class KinesisRecordProcessor( * ourselves from Spark's internal serialization strategy. * 3) For performance, the BlockGenerator is asynchronously queuing elements within its * memory before creating blocks. This prevents the small block scenario, but requires - * that you register callbacks to know when a block has been generated and stored + * that you register callbacks to know when a block has been generated and stored * (WAL is sufficient for storage) before can checkpoint back to the source. */ batch.foreach(record => receiver.store(record.getData().array())) - + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored + * Checkpoint the sequence number of the last record successfully processed/stored * in the batch. * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). * However, if the worker dies unexpectedly, a checkpoint may not happen. @@ -130,16 +130,16 @@ private[kinesis] class KinesisRecordProcessor( } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. 
No more records will be processed.") } } /** * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards + * 1) the stream is resharding by splitting or merging adjacent shards * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason * (ShutdownReason.ZOMBIE) * * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE @@ -153,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor( * Checkpoint to indicate that all records from the shard have been drained and processed. * It's now OK to read from the new shards that resulted from a resharding event. */ - case ShutdownReason.TERMINATE => + case ShutdownReason.TERMINATE => KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) /* diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2531aebe7813c..e5acab50181e1 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -55,7 +55,7 @@ object KinesisUtils { */ def createStream( ssc: StreamingContext, - kinesisAppName: String, + kinesisAppName: String, streamName: String, endpointUrl: String, regionName: String, @@ -102,7 +102,7 @@ object KinesisUtils { */ def createStream( ssc: StreamingContext, - kinesisAppName: String, + kinesisAppName: String, streamName: String, endpointUrl: String, regionName: String, diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index e14bbae4a9b6e..478d0019a25f0 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index d38a3aa8256b7..853dea9a7795e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 058c8c8aa1b24..ce1054ed92ba1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -26,8 +26,8 @@ class EdgeDirection private (private val name: String) extends Serializable { * out becomes in and both and either remain the same. 
*/ def reverse: EdgeDirection = this match { - case EdgeDirection.In => EdgeDirection.Out - case EdgeDirection.Out => EdgeDirection.In + case EdgeDirection.In => EdgeDirection.Out + case EdgeDirection.Out => EdgeDirection.In case EdgeDirection.Either => EdgeDirection.Either case EdgeDirection.Both => EdgeDirection.Both } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index cc70b396a8dd4..4611a3ace219b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -41,14 +41,16 @@ abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } + // scalastyle:on structural.type override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { - p.next._2.iterator.map(_.copy()) + p.next()._2.iterator.map(_.copy()) } else { Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index c8790cac3d8a0..65f82429d2029 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -37,7 +37,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { /** * Set the edge properties of this triplet. */ - protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = { + protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = { srcId = other.srcId dstId = other.dstId attr = other.attr diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 36dc7b0f86c89..db73a8abc5733 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -316,7 +316,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * satisfy the predicates */ def subgraph( - epred: EdgeTriplet[VD,ED] => Boolean = (x => true), + epred: EdgeTriplet[VD, ED] => Boolean = (x => true), vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 7edd627b20918..9451ff1e5c0e2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -124,18 +124,18 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { val nbrs = edgeDirection match { case EdgeDirection.Either => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => { ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) }, (a, b) => a ++ b, TripletFields.All) case EdgeDirection.In => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), 
(a, b) => a ++ b, TripletFields.Src) case EdgeDirection.Out => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), (a, b) => a ++ b, TripletFields.Dst) case EdgeDirection.Both => @@ -253,7 +253,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def filter[VD2: ClassTag, ED2: ClassTag]( preprocess: Graph[VD, ED] => Graph[VD2, ED2], epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true, - vpred: (VertexId, VD2) => Boolean = (v:VertexId, d:VD2) => true): Graph[VD, ED] = { + vpred: (VertexId, VD2) => Boolean = (v: VertexId, d: VD2) => true): Graph[VD, ED] = { graph.mask(preprocess(graph).subgraph(epred, vpred)) } @@ -356,7 +356,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)( vprog: (VertexId, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 01b013ff716fc..cfcf7244eaed5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -147,10 +147,10 @@ object Pregel extends Logging { logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs - oldMessages.unpersist(blocking=false) - newVerts.unpersist(blocking=false) - prevG.unpersistVertices(blocking=false) - prevG.edges.unpersist(blocking=false) + oldMessages.unpersist(blocking = false) + newVerts.unpersist(blocking = false) + prevG.unpersistVertices(blocking = false) + prevG.edges.unpersist(blocking = false) // count the iteration i += 1 } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index c561570809253..ab021a252eb8a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -156,8 +156,8 @@ class EdgePartition[ val size = data.size var i = 0 while (i < size) { - edge.srcId = srcIds(i) - edge.dstId = dstIds(i) + edge.srcId = srcIds(i) + edge.dstId = dstIds(i) edge.attr = data(i) newData(i) = f(edge) i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index bc974b2f04e70..8c0a461e99fa4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -116,7 +116,7 @@ object PageRank extends Logging { val personalized = srcId isDefined val src: VertexId = srcId.getOrElse(-1L) - def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 } + def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } var iteration = 0 var prevRankGraph: Graph[Double, Double] = null @@ -133,13 +133,13 @@ object PageRank extends Logging { // edge partitions. 
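The GraphOps.collectNeighbors hunks a few lines up are thin wrappers over Graph.aggregateMessages; as a minimal self-contained illustration of that API (the in-degree computation is only an example, not part of this patch):

import org.apache.spark.graphx._

// Count incoming edges per vertex: each edge triplet sends 1 to its
// destination, and the partial counts are summed per vertex. Vertices
// with no incoming edges simply do not appear in the result.
def countInDegrees[VD, ED](graph: Graph[VD, ED]): VertexRDD[Int] =
  graph.aggregateMessages[Int](
    ctx => ctx.sendToDst(1),   // sendMsg: runs once per edge triplet
    _ + _,                     // mergeMsg: combine partial counts
    TripletFields.None)        // no vertex attributes are needed
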
prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId ,id: VertexId) => resetProb * delta(src,id) + (src: VertexId , id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -243,7 +243,7 @@ object PageRank extends Logging { // Execute a dynamic version of Pregel. val vp = if (personalized) { - (id: VertexId, attr: (Double, Double),msgSum: Double) => + (id: VertexId, attr: (Double, Double), msgSum: Double) => personalizedVertexProgram(id, attr, msgSum) } else { (id: VertexId, attr: (Double, Double), msgSum: Double) => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 3b0e1628d86b5..9cb24ed080e1c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -210,7 +210,7 @@ object SVDPlusPlus { /** * Forces materialization of a Graph by count()ing its RDDs. */ - private def materialize(g: Graph[_,_]): Unit = { + private def materialize(g: Graph[_, _]): Unit = { g.vertices.count() g.edges.count() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index daf162085e3e4..a5d598053f9ca 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ */ object TriangleCount { - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Remove redundant edges val g = graph.groupEdges((a, b) => a).cache() @@ -49,7 +49,7 @@ object TriangleCount { var i = 0 while (i < nbrs.size) { // prevent self cycle - if(nbrs(i) != vid) { + if (nbrs(i) != vid) { set.add(nbrs(i)) } i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 2d6a825b61726..9591c4e9b8f4e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -243,14 +243,15 @@ object GraphGenerators { * @return A graph containing vertices with the row and column ids * as their attributes and edge values as 1.0. 
*/ - def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = { + def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = { // Convert row column address into vertex ids (row major order) def sub2ind(r: Int, c: Int): VertexId = r * cols + c - val vertices: RDD[(VertexId, (Int,Int))] = - sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) ) + val vertices: RDD[(VertexId, (Int, Int))] = sc.parallelize(0 until rows).flatMap { r => + (0 until cols).map( c => (sub2ind(r, c), (r, c)) ) + } val edges: RDD[Edge[Double]] = - vertices.flatMap{ case (vid, (r,c)) => + vertices.flatMap{ case (vid, (r, c)) => (if (r + 1 < rows) { Seq( (sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++ (if (c + 1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty }) }.map{ case (src, dst) => Edge(src, dst, 1.0) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala index eb1dbe52c2fda..f1ecc9e2219d1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel -class EdgeRDDSuite extends FunSuite with LocalSparkContext { +class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext { test("cache, getStorageLevel") { // test to see if getStorageLevel returns correct value after caching diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 5a2c73b414279..094a63472eaab 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class EdgeSuite extends FunSuite { +class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( - Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), - Edge(0x2345L, 0x1234L, 1), - Edge(0x1234L, 0x5678L, 1), - Edge(0x1234L, 0x2345L, 1), + Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), + Edge(0x2345L, 0x1234L, 1), + Edge(0x1234L, 0x5678L, 1), + Edge(0x1234L, 0x2345L, 1), Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1) ) // to ascending order val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int]) - + for (i <- 0 until testEdges.length) { assert(sortedEdges(i) == testEdges(testEdges.length - i - 1)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 9bc8007ce49cd..57a8b95dd12e9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd._ -import org.scalatest.FunSuite -class GraphOpsSuite extends FunSuite with LocalSparkContext { +class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { test("joinVertices") { 
withSpark { sc => @@ -59,7 +58,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test ("filter") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val filteredGraph = graph.filter( @@ -67,11 +66,11 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { val degrees: VertexRDD[Int] = graph.outDegrees graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} }, - vpred = (vid: VertexId, deg:Int) => deg > 0 + vpred = (vid: VertexId, deg: Int) => deg > 0 ).cache() val v = filteredGraph.vertices.collect().toSet - assert(v === Set((0,0))) + assert(v === Set((0, 0))) // the map is necessary because of object-reuse in the edge iterator val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index a570e4ed75fc3..1f5e27d5508b8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class GraphSuite extends FunSuite with LocalSparkContext { +class GraphSuite extends SparkFunSuite with LocalSparkContext { def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = { Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v") @@ -248,7 +246,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { test("mask") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() @@ -260,11 +258,11 @@ class GraphSuite extends FunSuite with LocalSparkContext { val projectedGraph = graph.mask(subgraph) val v = projectedGraph.vertices.collect().toSet - assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5))) + assert(v === Set((0, 0), (1, 1), (2, 2), (4, 4), (5, 5))) // the map is necessary because of object-reuse in the edge iterator val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet - assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5))) + assert(e === Set(Edge(0, 1, 1), Edge(0, 2, 2), Edge(0, 5, 5))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 490b94429ea1f..8afa2d403b53f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd._ -class PregelSuite extends FunSuite with LocalSparkContext { +class PregelSuite extends 
SparkFunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index d0a7198d691d7..f1aa685a79c98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.{HashPartitioner, SparkContext} +import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class VertexRDDSuite extends FunSuite with LocalSparkContext { +class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 515f3a9cd02eb..7435647c6d9ee 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class EdgePartitionSuite extends FunSuite { +class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { val builder = new EdgePartitionBuilder[A, Int] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index fe8304c1cdc32..1203f8959f506 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.graphx.impl -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class VertexPartitionSuite extends FunSuite { +class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 4cc30a96408f8..c965a6eb8df13 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import 
org.apache.spark.rdd._ -class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => @@ -52,13 +50,16 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - if(id < 10) { assert(cc === 0) } - else { assert(cc === 10) } + if (id < 10) { + assert(cc === 0) + } else { + assert(cc === 10) + } } val ccMap = vertices.toMap for (id <- 0 until 20) { @@ -75,7 +76,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() @@ -106,9 +107,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { (4L, ("peter", "student")))) // Create an RDD for edges val relationships: RDD[Edge[String]] = - sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), + sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"), - Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) + Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) // Edges are: // 2 ---> 5 ---> 3 // | \ diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala index 61fd0c4605568..808877f0590f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class LabelPropagationSuite extends FunSuite with LocalSparkContext { +class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext { test("Label Propagation") { withSpark { sc => // Construct a graph with two cliques connected by a single edge diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 3f3c9dfd7b3dd..45f1e3011035e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators @@ -31,14 +30,14 @@ object GridPageRank { def sub2ind(r: Int, c: Int): Int = r * nCols + c // Make the grid graph for (r <- 0 until 
nRows; c <- 0 until nCols) { - val ind = sub2ind(r,c) + val ind = sub2ind(r, c) if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r + 1,c)) += ind + inNbrs(sub2ind(r + 1, c)) += ind } if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c + 1)) += ind + inNbrs(sub2ind(r, c + 1)) += ind } } // compute the pagerank @@ -57,7 +56,7 @@ object GridPageRank { } -class PageRankSuite extends FunSuite with LocalSparkContext { +class PageRankSuite extends SparkFunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } @@ -99,8 +98,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val resetProb = 0.15 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb) + val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) .vertices.cache() // Static PageRank should only take 2 iterations to converge @@ -117,7 +116,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } assert(staticErrors.sum === 0) - val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache() + val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ -162,7 +161,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 7bd6b7f3c4ab2..2991438f5e57e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index f2c38e79c452c..d7eaa70ce6407 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class 
ShortestPathsSuite extends FunSuite with LocalSparkContext { +class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { test("Shortest Path Computations") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index 1f658c371ffcf..d6b03208180db 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Island Strongly Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 293c7f3ba4c21..c47552cf3a3bd 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut -class TriangleCountSuite extends FunSuite with LocalSparkContext { +class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => @@ -58,7 +57,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext { val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(0L -> -1L, -1L -> -2L, -2L -> 0L) - val revTriangles = triangles.map { case (a,b) => (b,a) } + val revTriangles = triangles.map { case (a, b) => (b, a) } val rawEdges = sc.parallelize(triangles ++ revTriangles, 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index f3b3738db0dad..186d0cc2a977b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BytecodeUtilsSuite extends FunSuite { +class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 8d9c8ddccbb3c..32e0c841c6997 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite - +import 
org.apache.spark.SparkFunSuite import org.apache.spark.graphx.LocalSparkContext -class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { +class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext { test("GraphGenerators.generateRandomEdges") { val src = 5 diff --git a/launcher/pom.xml b/launcher/pom.xml index ebfa7685eaa18..48dd0d5f9106b 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,14 +22,14 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml org.apache.spark spark-launcher_2.10 jar - Spark Launcher Project + Spark Project Launcher http://spark.apache.org/ launcher diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 33fd813f7a86c..33d65d13f0d25 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -296,6 +296,9 @@ Properties loadPropertiesFile() throws IOException { try { fd = new FileInputStream(propsFile); props.load(new InputStreamReader(fd, "UTF-8")); + for (Map.Entry e : props.entrySet()) { + e.setValue(e.getValue().toString().trim()); + } } finally { if (fd != null) { try { diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 929b29a49ed70..62492f9baf3bb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -53,21 +53,33 @@ public static void main(String[] argsArray) throws Exception { List args = new ArrayList(Arrays.asList(argsArray)); String className = args.remove(0); - boolean printLaunchCommand; - boolean printUsage; + boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); AbstractCommandBuilder builder; - try { - if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + try { builder = new SparkSubmitCommandBuilder(args); - } else { - builder = new SparkClassCommandBuilder(className, args); + } catch (IllegalArgumentException e) { + printLaunchCommand = false; + System.err.println("Error: " + e.getMessage()); + System.err.println(); + + MainClassOptionParser parser = new MainClassOptionParser(); + try { + parser.parse(args); + } catch (Exception ignored) { + // Ignore parsing exceptions. + } + + List help = new ArrayList(); + if (parser.className != null) { + help.add(parser.CLASS); + help.add(parser.className); + } + help.add(parser.USAGE_ERROR); + builder = new SparkSubmitCommandBuilder(help); } - printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - printUsage = false; - } catch (IllegalArgumentException e) { - builder = new UsageCommandBuilder(e.getMessage()); - printLaunchCommand = false; - printUsage = true; + } else { + builder = new SparkClassCommandBuilder(className, args); } Map env = new HashMap(); @@ -78,13 +90,7 @@ public static void main(String[] argsArray) throws Exception { } if (isWindows()) { - // When printing the usage message, we can't use "cmd /v" since that prevents the env - // variable from being seen in the caller script. So do not call prepareWindowsCommand(). 
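The AbstractCommandBuilder.loadPropertiesFile change above trims stray whitespace from every value read out of the properties file. The patch does this in Java; a rough Scala equivalent, just to show the intent (the file path is hypothetical):

import java.io.{FileInputStream, InputStreamReader}
import java.util.Properties
import scala.collection.JavaConverters._

// Load a properties file and strip leading/trailing whitespace from values,
// so entries like "spark.master  local[2]  " do not carry trailing spaces.
def loadTrimmedProperties(path: String): Properties = {
  val props = new Properties()
  val in = new InputStreamReader(new FileInputStream(path), "UTF-8")
  try props.load(in) finally in.close()
  props.stringPropertyNames().asScala.foreach { key =>
    props.setProperty(key, props.getProperty(key).trim)
  }
  props
}
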
- if (printUsage) { - System.out.println(join(" ", cmd)); - } else { - System.out.println(prepareWindowsCommand(cmd, env)); - } + System.out.println(prepareWindowsCommand(cmd, env)); } else { // In bash, use NULL as the arg separator since it cannot be used in an argument. List bashCmd = prepareBashCommand(cmd, env); @@ -135,33 +141,30 @@ private static List prepareBashCommand(List cmd, Map buildCommand(Map env) { - if (isWindows()) { - return Arrays.asList("set", "SPARK_LAUNCHER_USAGE_ERROR=" + message); - } else { - return Arrays.asList("usage", message, "1"); - } + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 7d387d406edae..3e5a2820b6c11 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -77,6 +77,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } private final List sparkArgs; + private final boolean printHelp; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed @@ -87,10 +88,11 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); + this.printHelp = false; } SparkSubmitCommandBuilder(List args) { - this(); + this.sparkArgs = new ArrayList(); List submitArgs = args; if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { this.allowsMixedArguments = true; @@ -104,14 +106,16 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = false; } - new OptionParser().parse(submitArgs); + OptionParser parser = new OptionParser(); + parser.parse(submitArgs); + this.printHelp = parser.helpRequested; } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -311,6 +315,8 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { + boolean helpRequested = false; + @Override protected boolean handle(String opt, String value) { if (opt.equals(MASTER)) { @@ -341,6 +347,9 @@ protected boolean handle(String opt, String value) { allowsMixedArguments = true; appResource = specialClasses.get(value); } + } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { + helpRequested = true; + sparkArgs.add(opt); } else { sparkArgs.add(opt); if (value != null) { @@ -360,6 +369,7 @@ protected boolean handleUnknown(String opt) { appArgs.add(opt); return true; } else { + checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); sparkArgs.add(opt); return false; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 229000087688f..b88bba883ac65 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ 
b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -61,6 +61,7 @@ class SparkSubmitOptionParser { // Options that do not take arguments. protected final String HELP = "--help"; protected final String SUPERVISE = "--supervise"; + protected final String USAGE_ERROR = "--usage-error"; protected final String VERBOSE = "--verbose"; protected final String VERSION = "--version"; @@ -120,6 +121,7 @@ class SparkSubmitOptionParser { final String[][] switches = { { HELP, "-h" }, { SUPERVISE }, + { USAGE_ERROR }, { VERBOSE, "-v" }, { VERSION }, }; diff --git a/make-distribution.sh b/make-distribution.sh index a2b0c431fb4d0..9f063da3a16c0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -141,22 +141,6 @@ SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hi # because we use "set -o pipefail" echo -n) -JAVA_CMD="$JAVA_HOME"/bin/java -JAVA_VERSION=$("$JAVA_CMD" -version 2>&1) -if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then - echo "***NOTE***: JAVA_HOME is not set to a JDK 6 installation. The resulting" - echo " distribution may not work well with PySpark and will not run" - echo " with Java 6 (See SPARK-1703 and SPARK-1911)." - echo " This test can be disabled by adding --skip-java-test." - echo "Output from 'java -version' was:" - echo "$JAVA_VERSION" - read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then - echo "Okay, exiting." - exit 1 - fi -fi - if [ "$NAME" == "none" ]; then NAME=$SPARK_HADOOP_VERSION fi diff --git a/mllib/pom.xml b/mllib/pom.xml index 0c07ca1a62fd3..b16058ddc203a 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 9e16e60270141..e9a5d7c0e7988 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for estimators that fit models to data. */ -@AlphaComponent +@DeveloperApi abstract class Estimator[M <: Model[M]] extends PipelineStage { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 70e7495ac616c..186bf7ae7a2f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. 
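A large share of this diff, including the GraphX and streaming suites above and the spark-core test-jar dependency added to the mllib pom just before this point, swaps org.scalatest.FunSuite for the new org.apache.spark.SparkFunSuite base class. For orientation, a migrated suite looks roughly like the sketch below; the suite name and test body are hypothetical, and the suite sits inside the org.apache.spark namespace as the in-tree suites do:

package org.apache.spark.graphx

import org.apache.spark.SparkFunSuite

// Only the base class changes; mixed-in traits (LocalSparkContext here)
// and the test bodies stay as they were.
class ExampleMigratedSuite extends SparkFunSuite with LocalSparkContext {

  test("star graph has n edges") {
    withSpark { sc =>
      val n = 5
      val rawEdges = sc.parallelize((1 to n).map(x => (0L, x.toLong)))
      val graph = Graph.fromEdgeTuples(rawEdges, 1)
      assert(graph.edges.count() === n.toLong)
    }
  }
}
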
* * @tparam M model type */ -@AlphaComponent +@DeveloperApi abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 43bee1b770e67..a9bd28df71ee1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -17,20 +17,23 @@ package org.apache.spark.ml +import java.{util => ju} + +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. */ -@AlphaComponent +@DeveloperApi abstract class PipelineStage extends Params with Logging { /** @@ -69,7 +72,7 @@ abstract class PipelineStage extends Params with Logging { } /** - * :: AlphaComponent :: + * :: Experimental :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will @@ -80,7 +83,7 @@ abstract class PipelineStage extends Params with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ -@AlphaComponent +@Experimental class Pipeline(override val uid: String) extends Estimator[PipelineModel] { def this() = this(Identifiable.randomUID("pipeline")) @@ -97,12 +100,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { /** @group getParam */ def getStages: Array[PipelineStage] = $(stages).clone() - override def validateParams(paramMap: ParamMap): Unit = { - val map = extractParamMap(paramMap) - getStages.foreach { - case pStage: Params => pStage.validateParams(map) - case _ => - } + override def validateParams(): Unit = { + super.validateParams() + $(stages).foreach(_.validateParams()) } /** @@ -169,15 +169,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { } /** - * :: AlphaComponent :: + * :: Experimental :: * Represents a fitted pipeline. */ -@AlphaComponent +@Experimental class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { + /** A Java/Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, stages: ju.List[Transformer]) = { + this(uid, stages.asScala.toArray) + } + override def validateParams(): Unit = { super.validateParams() stages.foreach(_.validateParams()) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index ec0f76aa668bd..e752b81a14282 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,6 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * * Abstraction for prediction problems (regression and classification). 
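The Pipeline scaladoc above (stages run in order; Estimators are fit and their models used to transform, Transformers just transform) is easiest to see in a short usage sketch. The column names and training DataFrame below are assumptions; the classes are the standard spark.ml ones:

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.DataFrame

// Two Transformers followed by an Estimator; fit() runs the stages in order
// and returns a PipelineModel whose transform() replays the fitted stages.
def fitTextPipeline(training: DataFrame): PipelineModel = {
  val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
  val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
  val lr = new LogisticRegression().setMaxIter(10)
  new Pipeline().setStages(Array(tokenizer, hashingTF, lr)).fit(training)
}
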
* * @tparam FeaturesType Type of features. @@ -113,7 +112,6 @@ abstract class Predictor[ * * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. */ - @DeveloperApi private[ml] def featuresDataType: DataType = new VectorUDT override def transformSchema(schema: StructType): StructType = { @@ -134,7 +132,6 @@ abstract class Predictor[ /** * :: DeveloperApi :: - * * Abstraction for a model for prediction tasks (regression and classification). * * @tparam FeaturesType Type of features. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 38bb6a5a5391e..f07f733a5ddb5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml import scala.annotation.varargs import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame @@ -28,10 +28,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for transformers that transform one dataset into another. */ -@AlphaComponent +@DeveloperApi abstract class Transformer extends PipelineStage { /** @@ -73,10 +73,12 @@ abstract class Transformer extends PipelineStage { } /** + * :: DeveloperApi :: * Abstract class for transformers that take one input column, apply transformation, and output the * result as a new column. */ -private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] +@DeveloperApi +abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index f5f37aa77929c..457c15830fd38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -19,10 +19,12 @@ package org.apache.spark.ml.attribute import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Attributes that describe a vector ML column. * * @param name name of the attribute group (the ML column name) @@ -31,6 +33,7 @@ import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} * @param attrs optional array of attributes. Attribute will be copied with their corresponding * indices in the array. */ +@DeveloperApi class AttributeGroup private ( val name: String, val numAttributes: Option[Int], @@ -182,7 +185,11 @@ class AttributeGroup private ( } } -/** Factory methods to create attribute groups. */ +/** + * :: DeveloperApi :: + * Factory methods to create attribute groups. 
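Because UnaryTransformer above is promoted from private[ml] to a public @DeveloperApi, user code can now extend it directly. A hedged sketch; the class name is hypothetical and the override signatures assume the createTransformFunc/outputDataType contract used by the built-in transformers in this tree.

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{DataType, StringType}

// Upper-cases a string column; inputCol/outputCol setters are inherited from UnaryTransformer.
class UpperCaser(override val uid: String)
  extends UnaryTransformer[String, String, UpperCaser] {

  def this() = this(Identifiable.randomUID("upperCaser"))

  override protected def createTransformFunc: String => String = _.toUpperCase

  override protected def outputDataType: DataType = StringType
}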
+ */ +@DeveloperApi object AttributeGroup { import AttributeKeys._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala index a83febd7de2cc..5c7089b491677 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -17,12 +17,17 @@ package org.apache.spark.ml.attribute +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], * and [[AttributeType$#Binary]]. */ +@DeveloperApi sealed abstract class AttributeType(val name: String) +@DeveloperApi object AttributeType { /** Numeric type. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index e8f7f152784a1..ce43a450daad0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -19,11 +19,14 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Abstract class for ML attributes. */ +@DeveloperApi sealed abstract class Attribute extends Serializable { name.foreach { n => @@ -135,6 +138,10 @@ private[attribute] trait AttributeFactory { } } +/** + * :: DeveloperApi :: + */ +@DeveloperApi object Attribute extends AttributeFactory { private[attribute] override def fromMetadata(metadata: Metadata): Attribute = { @@ -163,6 +170,7 @@ object Attribute extends AttributeFactory { /** + * :: DeveloperApi :: * A numeric attribute with optional summary statistics. * @param name optional name * @param index optional index @@ -171,6 +179,7 @@ object Attribute extends AttributeFactory { * @param std optional standard deviation * @param sparsity optional sparsity (ratio of zeros) */ +@DeveloperApi class NumericAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -278,8 +287,10 @@ class NumericAttribute private[ml] ( } /** + * :: DeveloperApi :: * Factory methods for numeric attributes. */ +@DeveloperApi object NumericAttribute extends AttributeFactory { /** The default numeric attribute. */ @@ -298,6 +309,7 @@ object NumericAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A nominal attribute. * @param name optional name * @param index optional index @@ -306,6 +318,7 @@ object NumericAttribute extends AttributeFactory { * defined. * @param values optional values. At most one of `numValues` and `values` can be defined. */ +@DeveloperApi class NominalAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -430,7 +443,11 @@ class NominalAttribute private[ml] ( } } -/** Factory methods for nominal attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for nominal attributes. + */ +@DeveloperApi object NominalAttribute extends AttributeFactory { /** The default nominal attribute. */ @@ -450,11 +467,13 @@ object NominalAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A binary attribute. * @param name optional name * @param index optional index * @param values optionla values. If set, its size must be 2. 
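The attribute classes above (Attribute, NumericAttribute, NominalAttribute, and the earlier AttributeGroup) are now public @DeveloperApi. A small sketch of how they describe ML columns through metadata; the attribute and column names are illustrative only.

import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}

val gender: Attribute = NominalAttribute.defaultAttr
  .withName("gender").withValues("male", "female")
val age: Attribute = NumericAttribute.defaultAttr
  .withName("age").withMin(0.0).withMax(120.0)

// An AttributeGroup describes a whole vector column and serializes into StructField metadata.
val group = new AttributeGroup("userFeatures", Array(gender, age))
val field = group.toStructField()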
*/ +@DeveloperApi class BinaryAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -526,7 +545,11 @@ class BinaryAttribute private[ml] ( } } -/** Factory methods for binary attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for binary attributes. + */ +@DeveloperApi object BinaryAttribute extends AttributeFactory { /** The default binary attribute. */ @@ -543,8 +566,10 @@ object BinaryAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * An unresolved attribute. */ +@DeveloperApi object UnresolvedAttribute extends Attribute { override def attrType: AttributeType = AttributeType.Unresolved diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 7c961332bf5b6..8030e0728a56c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,14 +31,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class DecisionTreeClassifier(override val uid: String) extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { @@ -89,19 +88,19 @@ final class DecisionTreeClassifier(override val uid: String) } } +@Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. 
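DecisionTreeClassifier and its model above move from AlphaComponent to @Experimental; GBTClassifier and RandomForestClassifier in the following hunks follow the same pattern. A usage sketch, assuming a `training` DataFrame with `label` and `features` columns; the tree classifiers read the number of classes from the label column's ML metadata, so raw labels are indexed first.

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.StringIndexer

val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel")
val indexed = labelIndexer.fit(training).transform(training)

val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setMaxDepth(5)
  .setImpurity("gini")       // supported impurities: "entropy", "gini"
val dtModel = dt.fit(indexed)
println(dtModel.rootNode)    // the fitted tree is exposed through rootNode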
*/ -@AlphaComponent +@Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, override val rootNode: Node) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d504d84beb91e..62f4b51f770e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -36,14 +36,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. */ -@AlphaComponent +@Experimental final class GBTClassifier(override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { @@ -144,6 +143,7 @@ final class GBTClassifier(override val uid: String) } } +@Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ @@ -151,8 +151,7 @@ object GBTClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for classification. * It supports binary labels, as well as both continuous and categorical features. @@ -160,7 +159,7 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class GBTClassificationModel( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], @@ -209,7 +208,7 @@ private[ml] object GBTClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. 
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8694c96e4c5b6..f136bcee9cf2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import breeze.linalg.{norm => brzNorm, DenseVector => BDV} -import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import breeze.optimize.{CachedDiffFunction, DiffFunction} +import breeze.linalg.{DenseVector => BDV, norm => brzNorm} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable @@ -35,7 +35,6 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel -import org.apache.spark.{SparkException, Logging} /** * Params for logistic regression. @@ -45,12 +44,11 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas with HasThreshold /** - * :: AlphaComponent :: - * + * :: Experimental :: * Logistic regression. * Currently, this class only supports binary classification. */ -@AlphaComponent +@Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with Logging { @@ -76,7 +74,7 @@ class LogisticRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -92,7 +90,11 @@ class LogisticRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) - /** @group setParam */ + /** + * Whether to fit an intercept term. + * Default is true. + * @group setParam + * */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -221,11 +223,10 @@ class LogisticRegression(override val uid: String) } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LogisticRegression]]. 
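The setters documented in the LogisticRegression hunk above are the usual way to configure the now-@Experimental estimator. A sketch; the `training` DataFrame is assumed.

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
  .setMaxIter(100)          // maximum number of iterations (default 100)
  .setRegParam(0.01)
  .setElasticNetParam(0.0)  // 0.0 = L2 penalty, 1.0 = L1 penalty
  .setFitIntercept(true)    // whether to fit an intercept term (default true)
val lrModel = lr.fit(training)
println(s"weights: ${lrModel.weights} intercept: ${lrModel.intercept}")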
*/ -@AlphaComponent +@Experimental class LogisticRegressionModel private[ml] ( override val uid: String, val weights: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1543f051ccd17..825f9ed1b54b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.Param @@ -37,11 +37,13 @@ import org.apache.spark.storage.StorageLevel */ private[ml] trait OneVsRestParams extends PredictorParams { + // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F type M <: ClassificationModel[F, M] type E <: Classifier[F, E, M] } + // scalastyle:on structural.type /** * param for the base binary classifier that we reduce multiclass classification into. @@ -54,8 +56,7 @@ private[ml] trait OneVsRestParams extends PredictorParams { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[OneVsRest]]. * This stores the models resulting from training k binary classifiers: one for each class. * Each example is scored against all k models, and the model with the highest score @@ -67,11 +68,11 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ -@AlphaComponent +@Experimental final class OneVsRestModel private[ml] ( override val uid: String, labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_,_]]) + val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { override def transformSchema(schema: StructType): StructType = { @@ -105,17 +106,17 @@ final class OneVsRestModel private[ml] ( // add temporary column to store intermediate scores and update val tmpColName = "mbc$tmp" + UUID.randomUUID().toString - val update: (Map[Int, Double], Vector) => Map[Int, Double] = + val update: (Map[Int, Double], Vector) => Map[Int, Double] = (predictions: Map[Int, Double], prediction: Vector) => { predictions + ((index, prediction(1))) } val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) - val transformedDataset = model.transform(df).select(columns:_*) + val transformedDataset = model.transform(df).select(columns : _*) val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName) + updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) } if (handlePersistence) { @@ -130,6 +131,7 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction val labelUdf = callUDF(label, DoubleType, col(accColName)) aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + .drop(accColName) } } @@ -191,7 +193,7 @@ final class OneVsRest(override val uid: String) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier 
classifier.fit(trainingDataset, classifier.labelCol -> labelColName) - }.toArray[ClassificationModel[_,_]] + }.toArray[ClassificationModel[_, _]] if (handlePersistence) { multiclassLabeled.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index a1de7919859eb..852a67e066322 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -33,14 +33,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for * classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class RandomForestClassifier(override val uid: String) extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { @@ -100,6 +99,7 @@ final class RandomForestClassifier(override val uid: String) } } +@Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities @@ -110,15 +110,14 @@ object RandomForestClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ -@AlphaComponent +@Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel]) @@ -171,7 +170,7 @@ private[ml] object RandomForestClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. 
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index ddbdd00ceb159..f695ddaeefc72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -28,11 +27,10 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** - * :: AlphaComponent :: - * + * :: Experimental :: * Evaluator for binary classification, which expects two input columns: score and label. */ -@AlphaComponent +@Experimental class BinaryClassificationEvaluator(override val uid: String) extends Evaluator with HasRawPredictionCol with HasLabelCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index cabd1c97c085c..61e937e693699 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for evaluators that compute metrics from predictions. */ -@AlphaComponent +@DeveloperApi abstract class Evaluator extends Params { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 80458928c5439..abb1b35bedea5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.{Param, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -26,19 +26,18 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** - * :: AlphaComponent :: - * + * :: Experimental :: * Evaluator for regression, which expects two input columns: prediction and label. */ -@AlphaComponent +@Experimental final class RegressionEvaluator(override val uid: String) extends Evaluator with HasPredictionCol with HasLabelCol { def this() = this(Identifiable.randomUID("regEval")) /** - * param for metric name in evaluation - * @group param supports mse, rmse, r2, mae as valid metric names. 
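OneVsRest and OneVsRestModel (two hunks up) are now @Experimental, and the model drops its temporary accumulator column from the output. A sketch that reduces a multiclass problem to binary LogisticRegression; the `training` and `test` DataFrames are assumptions.

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(10))
val ovrModel = ovr.fit(training)          // trains one binary classifier per class
val predicted = ovrModel.transform(test)  // internal accumulator column is dropped, per the change above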
+ * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * @group param */ val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 62f4a6343423e..b06122d733853 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ @@ -28,10 +28,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Binarize a column of continuous features given a threshold. */ -@AlphaComponent +@Experimental final class Binarizer(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ac8dfb5632a7b..a3d1f6f65ccaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -31,10 +31,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * `Bucketizer` maps a column of continuous features to a column of feature buckets. */ -@AlphaComponent +@Experimental final class Bucketizer(override val uid: String) extends Model[Bucketizer] with HasInputCol with HasOutputCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 8b32eee0e490a..1e758cb775de7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.Identifiable @@ -26,12 +26,12 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. 
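The evaluator hunks above (Evaluator is now @DeveloperApi; BinaryClassificationEvaluator and RegressionEvaluator are @Experimental) operate on a fitted model's output. A sketch; `classified` and `regressed` stand for DataFrames produced by model.transform() and are assumptions.

import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, RegressionEvaluator}

// For a binary classifier's output (defaults to areaUnderROC):
val auc = new BinaryClassificationEvaluator()
  .setRawPredictionCol("rawPrediction")
  .setLabelCol("label")
  .evaluate(classified)

// For a regressor's output; metricName supports "rmse" (default), "mse", "r2", "mae":
val rmse = new RegressionEvaluator()
  .setMetricName("rmse")
  .evaluate(regressed)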
*/ -@AlphaComponent +@Experimental class ElementwiseProduct(override val uid: String) extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { @@ -41,7 +41,7 @@ class ElementwiseProduct(override val uid: String) * the vector to multiply with input vectors * @group param */ - val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product") + val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ def setScalingVec(value: Vector): this.type = set(scalingVec, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 8942d45219177..f936aef80f8af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,22 +17,22 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.{udf, col} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. */ -@AlphaComponent +@Experimental class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("hashingTF")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 788c392050c2d..376b84530cd57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -58,10 +58,10 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol } /** - * :: AlphaComponent :: + * :: Experimental :: * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ -@AlphaComponent +@Experimental final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { def this() = this(Identifiable.randomUID("idf")) @@ -85,10 +85,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[IDF]]. 
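HashingTF and IDF/IDFModel above form the usual TF-IDF pair. A sketch assuming a `docs` DataFrame with a tokenized `words` column.

import org.apache.spark.ml.feature.{HashingTF, IDF}

val tf = new HashingTF()
  .setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(1 << 18)
val featurized = tf.transform(docs)

val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
val idfModel = idf.fit(featurized)           // IDF is an Estimator; IDFModel rescales the TF vectors
val tfidf = idfModel.transform(featurized)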
*/ -@AlphaComponent +@Experimental class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 3f689d1585cd6..8282e5ffa17f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.ml.util.Identifiable @@ -26,10 +26,10 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Normalize a vector to have unit norm using the given p-norm. */ -@AlphaComponent +@Experimental class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { def this() = this(Identifiable.randomUID("normalizer")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 1fb9b9ae75091..8f34878c8d329 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,93 +17,152 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} /** - * A one-hot encoder that maps a column of label indices to a column of binary vectors, with - * at most a single one-value. By default, the binary vector has an element for each category, so - * with 5 categories, an input value of 2.0 would map to an output vector of - * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the - * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - * linearly dependent because they sum up to one. + * :: Experimental :: + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] + * because it makes the vector entries sum up to one, and hence linearly dependent. 
+ * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * @see [[StringIndexer]] for converting categorical values into category indices */ -@AlphaComponent -class OneHotEncoder(override val uid: String) - extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol { +@Experimental +class OneHotEncoder(override val uid: String) extends Transformer + with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("oneHot")) /** - * Whether to include a component in the encoded vectors for the first category, defaults to true. + * Whether to drop the last category in the encoded vector (default: true) * @group param */ - final val includeFirst: BooleanParam = - new BooleanParam(this, "includeFirst", "include first category") - setDefault(includeFirst -> true) - - private var categories: Array[String] = _ + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) /** @group setParam */ - def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value) + def setDropLast(value: Boolean): this.type = set(dropLast, value) /** @group setParam */ - override def setInputCol(value: String): this.type = set(inputCol, value) + def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ - override def setOutputCol(value: String): this.type = set(outputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - val inputFields = schema.fields + val is = "_is_" + val inputColName = $(inputCol) val outputColName = $(outputCol) - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val inputColAttr = Attribute.fromStructField(schema($(inputCol))) - categories = inputColAttr match { + SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + val inputFields = schema.fields + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + val inputAttr = Attribute.fromStructField(schema(inputColName)) + val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => - nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray) - case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1")) + if (nominal.values.isDefined) { + nominal.values.map(_.map(v => inputColName + is + v)) + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values.map(_.map(v => inputColName + is + v)) + } else { + Some(Array.tabulate(2)(i => inputColName + is + i)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column $inputColName cannot be numeric.") case _ => - throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal") + None // optimistic about unknown attributes } - val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray - val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues) - val outputFields = inputFields :+ attr.toStructField() + val 
filteredOutputAttrNames = outputAttrNames.map { names => + if ($(dropLast)) { + require(names.length > 1, + s"The input column $inputColName should have at least two distinct values.") + names.dropRight(1) + } else { + names + } + } + + val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { + val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup($(outputCol), attrs) + } else { + new AttributeGroup($(outputCol)) + } + + val outputFields = inputFields :+ outputAttrGroup.toStructField() StructType(outputFields) } - protected override def createTransformFunc(): (Double) => Vector = { - val first = $(includeFirst) - val vecLen = if (first) categories.length else categories.length - 1 + override def transform(dataset: DataFrame): DataFrame = { + // schema transformation + val is = "_is_" + val inputColName = $(inputCol) + val outputColName = $(outputCol) + val shouldDropLast = $(dropLast) + var outputAttrGroup = AttributeGroup.fromStructField( + transformSchema(dataset.schema)(outputColName)) + if (outputAttrGroup.size < 0) { + // If the number of attributes is unknown, we check the values from the input column. + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + .aggregate(0.0)( + (m, x) => { + assert(x >=0.0 && x == x.toInt, + s"Values from column $inputColName must be indices, but got $x.") + math.max(m, x) + }, + (m0, m1) => { + math.max(m0, m1) + } + ).toInt + 1 + val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames + val outputAttrs: Array[Attribute] = + filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) + outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + } + val metadata = outputAttrGroup.toMetadata() + + // data transformation + val size = outputAttrGroup.size val oneValue = Array(1.0) val emptyValues = Array[Double]() val emptyIndices = Array[Int]() - label: Double => { - val values = if (first || label != 0.0) oneValue else emptyValues - val indices = if (first) { - Array(label.toInt) - } else if (label != 0.0) { - Array(label.toInt - 1) + val encode = udf { label: Double => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) } else { - emptyIndices + Vectors.sparse(size, emptyIndices, emptyValues) } - Vectors.sparse(vecLen, indices, values) } - } - /** - * Returns the data type of the output column. 
- */ - protected def outputDataType: DataType = new VectorUDT + dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 8ddf9d6a1e138..442e95820217a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamValidators} import org.apache.spark.ml.util.Identifiable @@ -27,14 +27,14 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an * expansion of a product of sums expresses it as a sum of products by using the fact that * multiplication distributes over addition". Take a 2-variable feature vector as an example: * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. */ -@AlphaComponent +@Experimental class PolynomialExpansion(override val uid: String) extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 5ccda15d872ed..b0fd06d84fdb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -35,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** * Centers the data with mean before scaling. - * It will build a dense output, so this does not work on sparse input + * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") - + /** * Scales the data to unit standard deviation. * Default: true @@ -51,11 +51,11 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with } /** - * :: AlphaComponent :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. 
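The rewritten OneHotEncoder above replaces includeFirst with dropLast, becomes a plain Transformer, and now emits ML attributes on the output column. A sketch pairing it with StringIndexer, which produces the category indices it expects; the `df` DataFrame and column names are assumptions.

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")
val indexed = indexer.fit(df).transform(df)

val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
  .setDropLast(true)          // default: the last category maps to the all-zeros vector
val encoded = encoder.transform(indexed)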
*/ -@AlphaComponent +@Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] with StandardScalerParams { @@ -68,13 +68,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + /** @group setParam */ def setWithMean(value: Boolean): this.type = set(withMean, value) - + /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - + override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } @@ -95,10 +95,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StandardScaler]]. */ -@AlphaComponent +@Experimental class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3f79b67309f07..f4e250757560a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -52,13 +52,13 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha } /** - * :: AlphaComponent :: + * :: Experimental :: * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ -@AlphaComponent +@Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] with StringIndexerBase { @@ -86,10 +86,13 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StringIndexer]]. + * NOTE: During transformation, if the input column does not exist, + * [[StringIndexerModel.transform]] would return the input dataset unmodified. + * This is a temporary fix for the case when target labels do not exist during prediction. */ -@AlphaComponent +@Experimental class StringIndexerModel private[ml] ( override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { @@ -112,6 +115,12 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + if (!dataset.schema.fieldNames.contains($(inputCol))) { + logInfo(s"Input column ${$(inputCol)} does not exist during transformation. 
" + + "Skip StringIndexerModel.") + return dataset + } + val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) @@ -128,6 +137,11 @@ class StringIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + if (schema.fieldNames.contains($(inputCol))) { + validateAndTransformSchema(schema) + } else { + // If the input column does not exist during transformation, we skip StringIndexerModel. + schema + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 31f3a1aa4c76b..21c15b6c33f6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,19 +17,19 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** - * :: AlphaComponent :: + * :: Experimental :: * A tokenizer that converts the input string to lowercase and then splits it by white spaces. * * @see [[RegexTokenizer]] */ -@AlphaComponent +@Experimental class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { def this() = this(Identifiable.randomUID("tok")) @@ -46,13 +46,13 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S } /** - * :: AlphaComponent :: + * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split * the text (default) or repeatedly matching the regex (if `gaps` is true). * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. */ -@AlphaComponent +@Experimental class RegexTokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], RegexTokenizer] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 181b62f46fce8..229ee27ec5942 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,8 +20,9 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} @@ -30,14 +31,14 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: Experimental :: * A feature transformer that merges multiple columns into a vector column. 
*/ -@AlphaComponent +@Experimental class VectorAssembler(override val uid: String) extends Transformer with HasInputCols with HasOutputCol { - def this() = this(Identifiable.randomUID("va")) + def this() = this(Identifiable.randomUID("vecAssembler")) /** @group setParam */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -46,19 +47,59 @@ class VectorAssembler(override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + // Schema transformation. + val schema = dataset.schema + lazy val first = dataset.first() + val attrs = $(inputCols).flatMap { c => + val field = schema(c) + val index = schema.fieldIndex(c) + field.dataType match { + case DoubleType => + val attr = Attribute.fromStructField(field) + // If the input column doesn't have ML attribute, assume numeric. + if (attr == UnresolvedAttribute) { + Some(NumericAttribute.defaultAttr.withName(c)) + } else { + Some(attr.withName(c)) + } + case _: NumericType | BooleanType => + // If the input column type is a compatible scalar type, assume numeric. + Some(NumericAttribute.defaultAttr.withName(c)) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(field) + if (group.attributes.isDefined) { + // If attributes are defined, copy them with updated names. + group.attributes.get.map { attr => + if (attr.name.isDefined) { + // TODO: Define a rigorous naming scheme. + attr.withName(c + "_" + attr.name.get) + } else { + attr + } + } + } else { + // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes + // from metadata, check the first row. + val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) + Array.fill(numAttrs)(NumericAttribute.defaultAttr) + } + } + } + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + + // Data transformation. val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } - val schema = dataset.schema - val inputColNames = $(inputCols) - val args = inputColNames.map { c => + val args = $(inputCols).map { c => schema(c).dataType match { case DoubleType => dataset(c) case _: VectorUDT => dataset(c) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol))) + + dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e238fb310ed37..1d0f23b4fb3db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,7 +22,7 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} @@ -56,8 +56,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Class for indexing categorical feature columns in a dataset of [[Vector]]. 
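VectorAssembler above keeps its basic behavior but now also propagates ML attributes into the output column's metadata (falling back to the first row when a vector column's size is unknown). A sketch assuming numeric columns `hour`, `clicks` and a vector column `userFeatures` in `df`.

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "clicks", "userFeatures"))
  .setOutputCol("features")
val assembled = assembler.transform(df)   // "features" carries the merged values plus attribute metadata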
* * This has 2 usage modes: @@ -91,7 +90,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Add warning if a categorical feature has only 1 category. * - Add option for allowing unknown categories. */ -@AlphaComponent +@Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] with VectorIndexerParams { @@ -230,8 +229,7 @@ private object VectorIndexer { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Transform categorical features to use 0-based indices instead of their original values. * - Categorical features are mapped to indices. * - Continuous features (columns) are left unchanged. @@ -246,7 +244,7 @@ private object VectorIndexer { * Values are maps from original features values to 0-based category indices. * If a feature is not in this map, it is treated as continuous. */ -@AlphaComponent +@Experimental class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index ed032669229ce..36f19509f0cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -82,11 +82,11 @@ private[feature] trait Word2VecBase extends Params } /** - * :: AlphaComponent :: + * :: Experimental :: * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further * natural language processing or machine learning process. */ -@AlphaComponent +@Experimental final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { def this() = this(Identifiable.randomUID("w2v")) @@ -135,10 +135,10 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[Word2Vec]]. */ -@AlphaComponent +@Experimental class Word2VecModel private[ml] ( override val uid: String, wordVectors: feature.Word2VecModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java index 00d9c802e930d..87f4223964ada 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -16,10 +16,10 @@ */ /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. 
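Word2Vec and Word2VecModel a few hunks up are now @Experimental. A usage sketch; the `docs` DataFrame with a `words` column of string sequences is an assumption.

import org.apache.spark.ml.feature.Word2Vec

val word2Vec = new Word2Vec()
  .setInputCol("words")
  .setOutputCol("wordVectors")
  .setVectorSize(100)
  .setMinCount(0)
val w2vModel = word2Vec.fit(docs)
val vectors = w2vModel.transform(docs)   // averages the vectors of the words in each row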
*/ -@AlphaComponent +@Experimental package org.apache.spark.ml; -import org.apache.spark.annotation.AlphaComponent; +import org.apache.spark.annotation.Experimental; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index ac75e9de1a8f2..c589d06d9f7e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -18,7 +18,7 @@ package org.apache.spark /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. * * @groupname param Parameters diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 12fc5b561f76e..ba94d6a3a80a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,11 +24,11 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A param with self-contained documentation and optionally default value. Primitive-typed param * should use the specialized versions, which are more friendly to Java users. * @@ -39,7 +39,7 @@ import org.apache.spark.ml.util.Identifiable * See [[ParamValidators]] for factory methods for common validation functions. * @tparam T param value type */ -@AlphaComponent +@DeveloperApi class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean) extends Serializable { @@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - /** - * Creates a param pair with the given value (for Java). - */ + /** Creates a param pair with the given value (for Java). */ def w(value: T): ParamPair[T] = this -> value - /** - * Creates a param pair with the given value (for Scala). - */ + /** Creates a param pair with the given value (for Scala). */ def ->(value: T): ParamPair[T] = ParamPair(this, value) override final def toString: String = s"${parent}__$name" @@ -174,7 +170,11 @@ object ParamValidators { // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... -/** Specialized version of [[Param[Double]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Double]]] for Java. + */ +@DeveloperApi class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean) extends Param[Double](parent, name, doc, isValid) { @@ -186,10 +186,15 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) } -/** Specialized version of [[Param[Int]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Int]]] for Java. 
+ */ +@DeveloperApi class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean) extends Param[Int](parent, name, doc, isValid) { @@ -201,10 +206,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) } -/** Specialized version of [[Param[Float]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Float]]] for Java. + */ +@DeveloperApi class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean) extends Param[Float](parent, name, doc, isValid) { @@ -216,10 +226,15 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) } -/** Specialized version of [[Param[Long]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Long]]] for Java. + */ +@DeveloperApi class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean) extends Param[Long](parent, name, doc, isValid) { @@ -231,47 +246,60 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) } -/** Specialized version of [[Param[Boolean]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Boolean]]] for Java. + */ +@DeveloperApi class BooleanParam(parent: String, name: String, doc: String) // No need for isValid extends Param[Boolean](parent, name, doc) { def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } -/** Specialized version of [[Param[Array[String]]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[String]]]] for Java. + */ +@DeveloperApi class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean) extends Param[Array[String]](parent, name, doc, isValid) { def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) - override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value) - /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) } -/** Specialized version of [[Param[Array[Double]]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[Double]]]] for Java. + */ +@DeveloperApi class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean) extends Param[Array[Double]](parent, name, doc, isValid) { def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) - override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value) - /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). 
*/ - def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray) + def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = + w(value.asScala.map(_.asInstanceOf[Double]).toArray) } /** + * :: Experimental :: * A param amd its value. */ +@Experimental case class ParamPair[T](param: Param[T], value: T) { // This is *the* place Param.validate is called. Whenever a parameter is specified, we should // always construct a ParamPair so that validate is called. @@ -279,11 +307,11 @@ case class ParamPair[T](param: Param[T], value: T) { } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Trait for components that take parameters. This also provides an internal param map to store * parameter values attached to the instance. */ -@AlphaComponent +@DeveloperApi trait Params extends Identifiable with Serializable { /** @@ -303,19 +331,6 @@ trait Params extends Identifiable with Serializable { .map(m => m.invoke(this).asInstanceOf[Param[_]]) } - /** - * Validates parameter values stored internally plus the input parameter map. - * Raises an exception if any parameter is invalid. - * - * This only needs to check for interactions between parameters. - * Parameter value checks which do not depend on other parameters are handled by - * [[Param.validate()]]. This method does not handle input/output column parameters; - * those are checked during schema validation. - */ - def validateParams(paramMap: ParamMap): Unit = { - copy(paramMap).validateParams() - } - /** * Validates parameter values stored internally. * Raise an exception if any parameter value is invalid. @@ -541,10 +556,10 @@ trait Params extends Identifiable with Serializable { abstract class JavaParams extends Params /** - * :: AlphaComponent :: + * :: Experimental :: * A param to value map. */ -@AlphaComponent +@Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -665,6 +680,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) def size: Int = map.size } +@Experimental object ParamMap { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1ffb5eddc36bd..8ffbcf0d8bc71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen { val params = Seq( ParamDesc[Double]("regParam", "regularization parameter (>= 0)", isValid = "ParamValidators.gtEq(0)"), - ParamDesc[Int]("maxIter", "max number of iterations (>= 0)", + ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)", isValid = "ParamValidators.gtEq(0)"), ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index ed08417bd4df8..a0c8ccdac9ad9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params { private[ml] trait HasMaxIter extends Params { /** - * Param for max number of iterations (>= 0). 
+ * Param for maximum number of iterations (>= 0). * @group param */ - final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) + final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getMaxIter: Int = $(maxIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2a5ddbfae5cdf..df009d855ecbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,25 +31,50 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.netlib.util.intW import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom +/** + * Common params for ALS and ALSModel. + */ +private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { + /** + * Param for the column name for user ids. + * Default: "user" + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids") + + /** @group getParam */ + def getUserCol: String = $(userCol) + + /** + * Param for the column name for item ids. + * Default: "item" + * @group param + */ + val itemCol = new Param[String](this, "itemCol", "column name for item ids") + + /** @group getParam */ + def getItemCol: String = $(itemCol) +} + /** * Common params for ALS. */ -private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam +private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval with HasSeed { /** @@ -105,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** @group getParam */ def getAlpha: Double = $(alpha) - /** - * Param for the column name for user ids. - * Default: "user" - * @group param - */ - val userCol = new Param[String](this, "userCol", "column name for user ids") - - /** @group getParam */ - def getUserCol: String = $(userCol) - - /** - * Param for the column name for item ids. - * Default: "item" - * @group param - */ - val itemCol = new Param[String](this, "itemCol", "column name for item ids") - - /** @group getParam */ - def getItemCol: String = $(itemCol) - /** * Param for the column name for ratings. 
* Default: "rating" @@ -156,58 +161,66 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - require(schema($(userCol)).dataType == IntegerType) - require(schema($(itemCol)).dataType== IntegerType) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType require(ratingType == FloatType || ratingType == DoubleType) - val predictionColName = $(predictionCol) - require(!schema.fieldNames.contains(predictionColName), - s"Prediction column $predictionColName already exists.") - val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false) - StructType(newFields) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** + * :: Experimental :: * Model fitted by ALS. + * + * @param rank rank of the matrix factorization model + * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features` + * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ +@Experimental class ALSModel private[ml] ( override val uid: String, - k: Int, - userFactors: RDD[(Int, Array[Float])], - itemFactors: RDD[(Int, Array[Float])]) - extends Model[ALSModel] with ALSParams { + val rank: Int, + @transient val userFactors: DataFrame, + @transient val itemFactors: DataFrame) + extends Model[ALSModel] with ALSModelParams { + + /** @group setParam */ + def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ + def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) override def transform(dataset: DataFrame): DataFrame = { - import dataset.sqlContext.implicits._ - val users = userFactors.toDF("id", "features") - val items = itemFactors.toDF("id", "features") - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { - blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } dataset - .join(users, dataset($(userCol)) === users("id"), "left") - .join(items, dataset($(itemCol)) === items("id"), "left") - .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol))) + .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") + .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .select(dataset("*"), + predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** + * :: Experimental :: * Alternating Least Squares (ALS) matrix factorization. * * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, @@ -236,6 +249,7 @@ class ALSModel private[ml] ( * indicated user * preferences rather than explicit ratings given to items. 
*/ +@Experimental class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating @@ -295,6 +309,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def fit(dataset: DataFrame): ALSModel = { + import dataset.sqlContext.implicits._ val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), col($(ratingCol)).cast(FloatType)) @@ -306,7 +321,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), alpha = $(alpha), nonnegative = $(nonnegative), checkpointInterval = $(checkpointInterval), seed = $(seed)) - val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this) + val userDF = userFactors.toDF("id", "features") + val itemDF = itemFactors.toDF("id", "features") + val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) copyValues(model) } @@ -326,7 +343,11 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { @DeveloperApi object ALS extends Logging { - /** Rating class for better code readability. */ + /** + * :: DeveloperApi :: + * Rating class for better code readability. + */ + @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) /** Trait for least squares solvers applied to the normal equation. */ @@ -487,8 +508,10 @@ object ALS extends Logging { } /** + * :: DeveloperApi :: * Implementation of the ALS algorithm. */ + @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore ratings: RDD[Rating[ID]], rank: Int = 10, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index e67df21b2e4ae..43b68e7bb20fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,13 +31,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for regression. * It supports both continuous and categorical features. 
*/ -@AlphaComponent +@Experimental final class DecisionTreeRegressor(override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeParams with TreeRegressorParams { @@ -79,19 +78,19 @@ final class DecisionTreeRegressor(override val uid: String) } } +@Experimental object DecisionTreeRegressor { /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. * It supports both continuous and categorical features. * @param rootNode Root of the decision tree */ -@AlphaComponent +@Experimental final class DecisionTreeRegressionModel private[ml] ( override val uid: String, override val rootNode: Node) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 4249ff5c1ebc7..b7e374bb6cb49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -35,13 +35,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class GBTRegressor(override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTParams with TreeRegressorParams with Logging { @@ -134,6 +133,7 @@ final class GBTRegressor(override val uid: String) } } +@Experimental object GBTRegressor { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @@ -141,7 +141,7 @@ object GBTRegressor { } /** - * :: AlphaComponent :: + * :: Experimental :: * * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for regression. @@ -149,7 +149,7 @@ object GBTRegressor { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. 
*/ -@AlphaComponent +@Experimental final class GBTRegressionModel( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], @@ -198,7 +198,7 @@ private[ml] object GBTRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 3ebb78f79201a..70cd8e9e87fae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -23,7 +23,7 @@ import breeze.linalg.{DenseVector => BDV, norm => brzNorm} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} @@ -44,8 +44,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** - * :: AlphaComponent :: - * + * :: Experimental :: * Linear regression. * * The learning objective is to minimize the squared error, with regularization. @@ -58,7 +57,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L1 (Lasso) * - L2 + L1 (elastic net) */ -@AlphaComponent +@Experimental class LinearRegression(override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with Logging { @@ -84,7 +83,7 @@ class LinearRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -190,11 +189,10 @@ class LinearRegression(override val uid: String) } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LinearRegression]]. 
*/ -@AlphaComponent +@Experimental class LinearRegressionModel private[ml] ( override val uid: String, val weights: Vector, @@ -323,7 +321,7 @@ private class LeastSquaresAggregator( } (weightsArray, -sum + labelMean / labelStd, weightsArray.length) } - + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) private val gradientSumArray = Array.ofDim[Double](dim) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82437aa8de294..49a1f7ce8c995 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,12 +31,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class RandomForestRegressor(override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestParams with TreeRegressorParams { @@ -89,6 +88,7 @@ final class RandomForestRegressor(override val uid: String) } } +@Experimental object RandomForestRegressor { /** Accessor for supported impurity settings: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -99,13 +99,12 @@ object RandomForestRegressor { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. * @param _trees Decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class RandomForestRegressionModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel]) @@ -153,7 +152,7 @@ private[ml] object RandomForestRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. 
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } new RandomForestRegressionModel(parent.uid, newTrees) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d2dec0c76cb12..4242154be14ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} - /** + * :: DeveloperApi :: * Decision tree node interface. */ +@DeveloperApi sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree @@ -89,10 +91,12 @@ private[ml] object Node { } /** + * :: DeveloperApi :: * Decision tree leaf node. * @param prediction Prediction this node makes * @param impurity Impurity measure at this node (for training data) */ +@DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double) extends Node { @@ -118,6 +122,7 @@ final class LeafNode private[ml] ( } /** + * :: DeveloperApi :: * Internal Decision Tree node. * @param prediction Prediction this node would make if it were a leaf node * @param impurity Impurity measure at this node (for training data) @@ -127,6 +132,7 @@ final class LeafNode private[ml] ( * @param rightChild Right-hand child node * @param split Information about the test used to split to the left or right child. */ +@DeveloperApi final class InternalNode private[ml] ( override val prediction: Double, override val impurity: Double, @@ -153,9 +159,9 @@ final class InternalNode private[ml] ( override private[tree] def subtreeToString(indentFactor: Int = 0): String = { val prefix: String = " " * indentFactor - prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" + leftChild.subtreeToString(indentFactor + 1) + - prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" + rightChild.subtreeToString(indentFactor + 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 90f1d052764d3..7acdeeee72d23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} /** + * :: DeveloperApi :: * Interface for a "Split," which specifies a test made at a decision tree node * to choose the left or right path. */ +@DeveloperApi sealed trait Split extends Serializable { /** Index of feature which this split tests */ @@ -52,12 +55,14 @@ private[tree] object Split { } /** + * :: DeveloperApi :: * Split which tests a categorical feature. * @param featureIndex Index of the feature to test * @param _leftCategories If the feature value is in this set of categories, then the split goes * left. Otherwise, it goes right. 
* @param numCategories Number of categories for this feature. */ +@DeveloperApi final class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], @@ -125,11 +130,13 @@ final class CategoricalSplit private[ml] ( } /** + * :: DeveloperApi :: * Split which tests a continuous feature. * @param featureIndex Index of the feature to test * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. */ +@DeveloperApi final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) extends Split { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 816fcedf2efb3..a0c5238d966bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} @@ -26,12 +25,10 @@ import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldG import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} /** - * :: DeveloperApi :: * Parameters for Decision Tree-based algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait DecisionTreeParams extends PredictorParams { /** @@ -265,12 +262,10 @@ private[ml] object TreeRegressorParams { } /** - * :: DeveloperApi :: * Parameters for Decision Tree-based ensemble algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @@ -307,12 +302,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { } /** - * :: DeveloperApi :: * Parameters for Random Forest algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait RandomForestParams extends TreeEnsembleParams { /** @@ -377,12 +370,10 @@ private[ml] object RandomForestParams { } /** - * :: DeveloperApi :: * Parameters for Gradient-Boosted Tree algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index e21ff94a20f54..cb29392e8bc63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ @@ -79,10 +79,10 @@ private[ml] trait CrossValidatorParams extends Params { } /** - * :: AlphaComponent :: + * :: Experimental :: * K-fold cross validation. 
*/ -@AlphaComponent +@Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { @@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def validateParams(paramMap: ParamMap): Unit = { - getEstimatorParamMaps.foreach { eMap => - getEstimator.validateParams(eMap ++ paramMap) - } - } - override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -141,26 +135,35 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } } /** - * :: AlphaComponent :: + * :: Experimental :: * Model from k-fold cross validation. */ -@AlphaComponent +@Experimental class CrossValidatorModel private[ml] ( override val uid: String, - val bestModel: Model[_]) + val bestModel: Model[_], + val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def validateParams(paramMap: ParamMap): Unit = { - bestModel.validateParams(paramMap) + override def validateParams(): Unit = { + bestModel.validateParams() } override def transform(dataset: DataFrame): DataFrame = { @@ -171,4 +174,12 @@ class CrossValidatorModel private[ml] ( override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + + override def copy(extra: ParamMap): CrossValidatorModel = { + val copied = new CrossValidatorModel( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + avgMetrics.clone()) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index dafe73d82c00a..98a8f0330ca45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ /** - * :: AlphaComponent :: + * :: Experimental :: * Builder for a param grid used in grid search-based model selection. 
*/ -@AlphaComponent +@Experimental class ParamGridBuilder { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2fa54df6fc2b2..8f66bc808a007 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -43,7 +43,8 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses @@ -392,14 +393,14 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[Vector], wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Vector] = { + si: Array[Object]): RDD[Vector] = { val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) - } + } val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data).map(Vectors.dense) } @@ -428,7 +429,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -459,7 +460,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -494,7 +495,7 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. 
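As a quick illustration of the CrossValidator changes above (the fitted CrossValidatorModel now exposes the per-ParamMap average metrics as avgMetrics alongside bestModel), here is a minimal, hypothetical sketch of inspecting a tuned model. The LogisticRegression estimator, BinaryClassificationEvaluator, and the `training` DataFrame are illustrative assumptions, not part of this patch.

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    // `training` is an assumed DataFrame with "label" and "features" columns.
    val lr = new LogisticRegression()
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.01, 0.1))
      .addGrid(lr.maxIter, Array(10, 50))
      .build()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)
    val cvModel = cv.fit(training)
    // New in this patch: average cross-validation metric for each candidate ParamMap.
    cvModel.avgMetrics.zip(grid).foreach { case (metric, paramMap) =>
      println(s"$metric for:\n$paramMap")
    }
    val best = cvModel.bestModel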
@@ -945,6 +946,15 @@ private[python] class PythonMLLibAPI extends Serializable { r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) } + /** + * Java stub for the estimate method of KernelDensity + */ + def estimateKernelDensity( + sample: JavaRDD[Double], + bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = { + return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( + points.asScala.toArray) + } } @@ -1242,7 +1252,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index c88410ac0ff43..fc509d2ba1470 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -36,11 +37,11 @@ import org.apache.spark.util.Utils * independent Gaussian distributions with associated "mixing" weights * specifying each's contribution to the composite. * - * Given a set of sample points, this class will maximize the log-likelihood - * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed - * to find a global optimum. + * to find a global optimum. * * Note: For high-dimensional data (with many features), this algorithm may perform poorly. * This is due to high-dimensional data (a) making it difficult to cluster at all (based @@ -53,24 +54,24 @@ import org.apache.spark.util.Utils */ @Experimental class GaussianMixture private ( - private var k: Int, - private var convergenceTol: Double, + private var k: Int, + private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - + /** * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) - + // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 - - // an initializing GMM can be provided rather than using the + + // an initializing GMM can be provided rather than using the // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - + /** Set the initial GMM starting point, bypassing the random initialization. 
* You must call setK() prior to calling this method, and the condition * (model.k == this.k) must be met; failure will result in an IllegalArgumentException @@ -83,37 +84,37 @@ class GaussianMixture private ( } this } - + /** Return the user supplied initial GMM, if supplied */ def getInitialModel: Option[GaussianMixtureModel] = initialModel - + /** Set the number of Gaussians in the mixture model. Default: 2 */ def setK(k: Int): this.type = { this.k = k this } - + /** Return the number of Gaussians in the mixture model */ def getK: Int = k - + /** Set the maximum number of iterations to run. Default: 100 */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - + /** Return the maximum number of iterations to run */ def getMaxIterations: Int = maxIterations - + /** - * Set the largest change in log-likelihood at which convergence is + * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this } - + /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. @@ -132,41 +133,41 @@ class GaussianMixture private ( /** Perform expectation maximization */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext - + // we will operate on the data as breeze data val breezeData = data.map(_.toBreeze).cache() - + // Get length of the input vectors val d = breezeData.first().length - + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and // diagonal covariance matrices using component variances - // derived from the samples + // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - + case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) - (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) - new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) } } - - var llh = Double.MinValue // current log-likelihood + + var llh = Double.MinValue // current log-likelihood var llhp = 0.0 // previous log-likelihood - + var iter = 0 while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) { // create and broadcast curried cluster contribution function val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) - + // aggregate the cluster contribution for all sample points val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) - + // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum @@ -179,22 +180,25 @@ class GaussianMixture private ( gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) i = i + 1 } - + llhp = llh // current becomes previous llh = sums.logLikelihood // this is the freshly computed log-likelihood iter += 1 - } - + } + new GaussianMixtureModel(weights, gaussians) } - + + /** Java-friendly version of [[run()]] */ + def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) + /** Average of dense breeze vectors */ private def vectorMean(x: 
IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) - v / x.length.toDouble + v / x.length.toDouble } - + /** * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) @@ -210,14 +214,14 @@ class GaussianMixture private ( // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { - new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + new ExpectationSum(0.0, Array.fill(k)(0.0), + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d))) } - + // compute cluster contributions for each input point // (U, T) => U for aggregation def add( - weights: Array[Double], + weights: Array[Double], dists: Array[MultivariateGaussian]) (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { @@ -235,7 +239,7 @@ private object ExpectationSum { i = i + 1 } sums - } + } } // Aggregation class for partial expectation results @@ -244,9 +248,9 @@ private class ExpectationSum( val weights: Array[Double], val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { - + val k = weights.length - + def +=(x: ExpectationSum): ExpectationSum = { var i = 0 while (i < k) { @@ -257,5 +261,5 @@ private class ExpectationSum( } logLikelihood += x.logLikelihood this - } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 86353aed81156..cb807c8038101 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} @@ -34,10 +35,10 @@ import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: * - * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points - * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are - * the respective mean and covariance for each Gaussian distribution i=1..k. - * + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. 
+ * * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is * the weight for Gaussian i, and weights.sum == 1 * @param gaussians Array of MultivariateGaussian where gaussians(i) represents @@ -45,9 +46,9 @@ import org.apache.spark.sql.{SQLContext, Row} */ @Experimental class GaussianMixtureModel( - val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ - + val weights: Array[Double], + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") override protected def formatVersion = "1.0" @@ -64,20 +65,24 @@ class GaussianMixtureModel( val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - + + /** Java-friendly version of [[predict()]] */ + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + /** * Given the input vectors, return the membership value of each vector - * to all mixture components. + * to all mixture components. */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) val bcWeights = sc.broadcast(weights) - points.map { x => + points.map { x => computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } - + /** * Compute the partial assignments for each vector */ @@ -89,7 +94,7 @@ class GaussianMixtureModel( val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) } - val pSum = p.sum + val pSum = p.sum for (i <- 0 until k) { p(i) /= pSum } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6cf26445f20a0..974b26924dfb8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.rdd.RDD @@ -345,6 +346,11 @@ class DistributedLDAModel private ( } } + /** Java-friendly version of [[topicDistributions]] */ + def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { + JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 6fa2fe053c6a4..8e5154b902d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -273,7 +273,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Default: 1024, following the original Online LDA paper. 
*/ def setTau0(tau0: Double): this.type = { - require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") + require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") this.tau0 = tau0 this } @@ -339,7 +339,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { override private[clustering] def initialize( docs: RDD[(Long, Vector)], - lda: LDA): OnlineLDAOptimizer = { + lda: LDA): OnlineLDAOptimizer = { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * uses digamma which is accurate but expensive. */ private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { - val rowSum = sum(alpha(breeze.linalg.*, ::)) + val rowSum = sum(alpha(breeze.linalg.*, ::)) val digAlpha = digamma(alpha) val digRowSum = digamma(rowSum) val result = digAlpha(::, breeze.linalg.*) - digRowSum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 1ed01c9d8ba0b..e7a243f854e33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + * initMode: "random"}. */ def this() = this(k = 2, maxIterations = 100, initMode = "random") @@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging { /** * Generates random vertex properties (v0) to start power iteration. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm @@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging { * Generates the degree vector as the vertex properties (v0) to start power iteration. * It is not exactly the node degrees but just the normalized sum similarities. Call it * as degree vector because it is used in the PIC paper. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ @@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging { val v0 = g.vertices.mapValues(_ / sum) GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) } - + /** * Runs power iteration. 
* @param g input graph with edges representing the normalized affinity matrix (W) and vertices diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 812014a041719..d9b34cec64894 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -21,8 +21,10 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -178,7 +180,7 @@ class StreamingKMeans( /** Set the decay factor directly (for forgetful algorithms). */ def setDecayFactor(a: Double): this.type = { - this.decayFactor = decayFactor + this.decayFactor = a this } @@ -234,6 +236,9 @@ class StreamingKMeans( } } + /** Java-friendly version of `trainOn`. */ + def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) + /** * Use the clustering model to make predictions on batches of data from a DStream. * @@ -245,6 +250,11 @@ class StreamingKMeans( data.map(model.predict) } + /** Java-friendly version of `predictOn`. */ + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) + } + /** * Use the model to make predictions on the values of a DStream and carry over its keys. * @@ -257,6 +267,14 @@ class StreamingKMeans( data.mapValues(model.predict) } + /** Java-friendly version of `predictOnValues`. */ + def predictOnValues[K]( + data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]]) + } + /** Check whether cluster centers have been initialized. */ private[this] def assertInitialized(): Unit = { if (model.clusterCenters == null) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 9cc2d0ffcab7d..5f8c1dea237b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -108,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * (ordered by statistic value descending) */ @Experimental -class ChiSqSelector (val numTopFeatures: Int) { +class ChiSqSelector (val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. 
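Among the StreamingKMeans changes above, setDecayFactor is a genuine bug fix rather than a style cleanup: the old body assigned the field to itself, so the decay factor never changed. A minimal sketch of the corrected setter pattern (hypothetical class, not the MLlib one):

    class DecaySetting {
      private var decayFactor: Double = 1.0
      def setDecayFactor(a: Double): this.type = {
        this.decayFactor = a  // fixed: assign the argument, not the field to itself
        this
      }
      def value: Double = decayFactor
    }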
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index b0985baf9b278..d67fe6c3ee4f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._ * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. - * @param scalingVector The values used to scale the reference vector's individual components. + * @param scalingVec The values used to scale the reference vector's individual components. */ @Experimental -class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { +class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. @@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { * @return transformed vector. */ override def transform(vector: Vector): Vector = { - require(vector.size == scalingVector.size, - s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}") + require(vector.size == scalingVec.size, + s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") vector match { case dv: DenseVector => val values: Array[Double] = dv.values.clone() - val dim = scalingVector.size + val dim = scalingVec.size var i = 0 while (i < dim) { - values(i) *= scalingVector(i) + values(i) *= scalingVec(i) i += 1 } Vectors.dense(values) @@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { val dim = values.length var i = 0 while (i < dim) { - values(i) *= scalingVector(indices(i)) + values(i) *= scalingVec(indices(i)) i += 1 } Vectors.sparse(size, indices, values) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index a89eea0e21be2..efbfeb4059f5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -144,7 +144,7 @@ private object IDF { * Since arrays are initialized to 0 by default, * we just omit changing those entries. 
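The ElementwiseProduct diff above is a parameter rename (scalingVector to scalingVec); the transform itself remains a plain Hadamard product. A rough standalone sketch of the dense branch (not the MLlib transformer):

    def hadamard(values: Array[Double], scalingVec: Array[Double]): Array[Double] = {
      require(values.length == scalingVec.length, "vector sizes do not match")
      values.zip(scalingVec).map { case (v, s) => v * s }
    }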
*/ - if(df(j) >= minDocFreq) { + if (df(j) >= minDocFreq) { inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) } j += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 6ae6917eae595..c73b8f258060d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -90,7 +90,7 @@ class StandardScalerModel ( @DeveloperApi def setWithMean(withMean: Boolean): this.type = { - require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") this.withMean = withMean this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 9106b73dfcd76..51546d41c36a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -42,32 +42,32 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.sql.{SQLContext, Row} /** - * Entry in vocabulary + * Entry in vocabulary */ private case class VocabWord( var word: String, var cn: Int, var point: Array[Int], var code: Array[Int], - var codeLen:Int + var codeLen: Int ) /** * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus - * and then learns vector representation of words in the vocabulary. - * The vector representation can be used as features in + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in * natural language processing and machine learning algorithms. - * - * We used skip-gram model in our implementation and hierarchical softmax + * + * We used skip-gram model in our implementation and hierarchical softmax * method to train the model. The variable names in the implementation * matches the original C implementation. * - * For original C implementation, see https://code.google.com/p/word2vec/ - * For research papers, see + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see * Efficient Estimation of Word Representations in Vector Space - * and + * and * Distributed Representations of Words and Phrases and their Compositionality. */ @Experimental @@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 - + /** * Sets vector size (default: 100). */ @@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging { this } - /** - * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). 
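setMinCount controls which tokens enter the vocabulary at all; the learnVocab hunk further below applies exactly this filter (.filter(_.cn >= minCount)) after counting and before sorting by descending count. A tiny illustrative local sketch of that counting step (the real code works on an RDD of words):

    def buildVocab(words: Seq[String], minCount: Int): Seq[(String, Int)] =
      words.groupBy(identity).map { case (w, ws) => (w, ws.size) }.toSeq
        .filter { case (_, count) => count >= minCount }
        .sortBy { case (_, count) => -count }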
*/ def setMinCount(minCount: Int): this.type = { this.minCount = minCount this } - + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -150,13 +150,13 @@ class Word2Vec extends Serializable with Logging { .map(x => VocabWord( x._1, x._2, - new Array[Int](MAX_CODE_LENGTH), - new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), 0)) .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) - + vocabSize = vocab.length require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " + "the setting of minCount, which could be large enough to remove all your words in sentences.") @@ -198,8 +198,8 @@ class Word2Vec extends Serializable with Logging { } var pos1 = vocabSize - 1 var pos2 = vocabSize - - var min1i = 0 + + var min1i = 0 var min2i = 0 a = 0 @@ -268,15 +268,15 @@ class Word2Vec extends Serializable with Logging { val words = dataset.flatMap(x => x) learnVocab(words) - + createBinaryTree() - + val sc = dataset.context val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => new Iterator[Array[Int]] { def hasNext: Boolean = iter.hasNext @@ -297,7 +297,7 @@ class Word2Vec extends Serializable with Logging { } } } - + val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) @@ -402,7 +402,7 @@ class Word2Vec extends Serializable with Logging { } } newSentences.unpersist() - + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] var i = 0 while (i < vocabSize) { @@ -469,7 +469,7 @@ class Word2VecModel private[mllib] ( val norm1 = blas.snrm2(n, v1, 1) val norm2 = blas.snrm2(n, v2, 1) if (norm1 == 0 || norm2 == 0) return 0.0 - blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 } override protected def formatVersion = "1.0" @@ -480,7 +480,7 @@ class Word2VecModel private[mllib] ( /** * Transforms a word to its vector representation - * @param word a word + * @param word a word * @return vector representation of word */ def transform(word: String): Vector = { @@ -495,18 +495,18 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of a word * @param word a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector,num) + findSynonyms(vector, num) } /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index ec38529cf8fae..557119f7b1cd1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } - + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. 
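The Word2VecModel hunk above only adjusts spacing around the BLAS calls, but it shows the synonym scoring in full: cosine similarity with a zero-norm guard. A hedged plain-Scala sketch of the same computation (the model itself uses netlib snrm2/sdot):

    def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
      val dot = v1.zip(v2).map { case (a, b) => a.toDouble * b }.sum
      val norm1 = math.sqrt(v1.map(x => x.toDouble * x).sum)
      val norm2 = math.sqrt(v2.map(x => x.toDouble * x).sum)
      if (norm1 == 0.0 || norm2 == 0.0) 0.0 else dot / norm1 / norm2
    }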
@@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 } i += 1 - } + } } private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { @@ -505,7 +505,7 @@ private[spark] object BLAS extends Serializable with Logging { nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } - + /** * y := alpha * A * x + beta * y * For `DenseMatrix` A and `SparseVector` x. @@ -557,7 +557,7 @@ private[spark] object BLAS extends Serializable with Logging { } } } - + /** * y := alpha * A * x + beta * y * For `SparseMatrix` A and `SparseVector` x. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 866936aa4f118..ae3ba3099c878 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition { require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, s"k = $k and/or n = $n are too large to compute an eigendecomposition") - + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index f6bcdf83cd337..2ffa497a99d93 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } override def serialize(obj: Any): Row = { - val row = new GenericMutableRow(4) obj match { case SparseVector(size, indices, values) => + val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, indices.toSeq) row.update(3, values.toSeq) + row case DenseVector(values) => + val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) row.update(3, values.toSeq) + row + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case row: Row => + row } - row } override def deserialize(datum: Any): Vector = { datum match { - // TODO: something wrong with UDT serialization - case v: Vector => - v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") @@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. 
+ case v: Vector => + v } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 9a89a6f3a515f..1626da9c3d2ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -219,7 +219,7 @@ class RowMatrix( val computeMode = mode match { case "auto" => - if(k > 5000) { + if (k > 5000) { logWarning(s"computing svd with k=$k and n=$n, please check necessity") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 4b7d0589c973b..06e45e10c5bf4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -179,7 +179,7 @@ object GradientDescent extends Logging { * if it's L2 updater; for L1 updater, the same logic is followed. */ var regVal = updater.compute( - weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 for (i <- 1 to numIterations) { val bcWeights = data.context.broadcast(weights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 34b447584e521..622b53a252ac5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, + model : GeneralizedLinearModel, description : String, normalizationMethod : RegressionNormalizationMethodType, - threshold: Double) + threshold: Double) extends PMMLModelExport { populateBinaryClassificationPMML() @@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( .withUsageType(FieldUsageType.ACTIVE)) regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } - + // add target field val targetField = FieldName.create("target") dataDictionary @@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport( miningSchema .withMiningFields(new MiningField(targetField) .withUsageType(FieldUsageType.TARGET)) - + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) - + pmml.setDataDictionary(dataDictionary) pmml.withModels(regressionModel) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index ebdeae50bb32f..c5fdecd3ca17f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -25,7 +25,7 @@ import scala.beans.BeanProperty import org.dmg.pmml.{Application, Header, PMML, Timestamp} private[mllib] trait PMMLModelExport { - + /** * Holder of the exported model in PMML format */ @@ -33,7 +33,7 @@ private[mllib] trait 
PMMLModelExport { val pmml: PMML = new PMML setHeader(pmml) - + private def setHeader(pmml: PMML): Unit = { val version = getClass.getPackage.getImplementationVersion val app = new Application().withName("Apache Spark MLlib").withVersion(version) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index c16e83d6a067d..29bd689e1185a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel private[mllib] object PMMLModelExportFactory { - + /** - * Factory object to help creating the necessary PMMLModelExport implementation + * Factory object to help creating the necessary PMMLModelExport implementation * taking as input the machine learning model (for example KMeansModel). */ def createPMMLModelExport(model: Any): PMMLModelExport = { @@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory { new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") case svm: SVMModel => new BinaryClassificationPMMLModelExport( - svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm, "linear SVM", RegressionNormalizationMethodType.NONE, svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => if (logistic.numClasses == 2) { @@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory { "PMML Export not supported for model: " + model.getClass.getName) } } - + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 8341bb86afd71..174d5e0f6c9f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -52,7 +52,7 @@ object RandomRDDs { numPartitions: Int = 0, seed: Long = Utils.random.nextLong()): RDD[Double] = { val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** @@ -234,7 +234,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution - * @param scale scale parameter (> 0) for the gamma distribution + * @param scale scale parameter (> 0) for the gamma distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -293,7 +293,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param mean mean for the log normal distribution - * @param std standard deviation for the log normal distribution + * @param std standard deviation for the log normal distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -671,7 +671,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution. - * @param scale scale parameter (> 0) for the gamma distribution. 
+ * @param scale scale parameter (> 0) for the gamma distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index dddefe1944e9d..93290e6508529 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -175,7 +175,7 @@ class ALS private ( /** * :: DeveloperApi :: * Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default - * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. + * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. * `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement, * at the cost of speed. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 26be30ff9d6fd..6709bd79bc820 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ val initialWeights = { if (numOfLinearPredictor == 1) { - Vectors.dense(new Array[Double](numFeatures)) + Vectors.zeros(numFeatures) } else if (addIntercept) { - Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) } else { - Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + Vectors.zeros(numFeatures * numOfLinearPredictor) } } run(input, initialWeights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 3ea63dd8c0acd..f3b46c75c05f3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -170,15 +170,15 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { case class Data(boundary: Double, prediction: Double) def save( - sc: SparkContext, - path: String, - boundaries: Array[Double], - predictions: Array[Double], + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], isotonic: Boolean): Unit = { val sqlContext = new SQLContext(sc) val metadata = compact(render( - ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("isotonic" -> isotonic))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) @@ -203,7 +203,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = loadMetadata(sc, path) - val isotonic = (metadata \ "isotonic").extract[Boolean] + val isotonic = (metadata \ "isotonic").extract[Boolean] val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) 
match { case (className, "1.0") if className == classNameV1_0 => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index cea8f3f47307b..aee51bf22d8d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -83,15 +83,7 @@ abstract class StreamingLinearAlgorithm[ throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - val initialWeights = - model match { - case Some(m) => - m.weights - case None => - val numFeatures = rdd.first().features.size - Vectors.dense(numFeatures) - } - model = Some(algorithm.run(rdd, initialWeights)) + model = Some(algorithm.run(rdd, model.get.weights)) logInfo("Model updated at time %s".format(time.toString)) val display = model.get.weights.size match { case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index a49153bf73c0d..235e043c7754b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( this } - /** Set the initial weights. Default: [0.0, 0.0]. */ + /** Set the initial weights. */ def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index a6bfe26e1e4f5..58a50f9c19f14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -93,7 +93,7 @@ class KernelDensity extends Serializable { x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i)) i += 1 } - (x._1, n) + (x._1, x._2 + 1) }, (x, y) => { blas.daxpy(n, 1.0, y._1, 1, x._1, 1) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 0b1755613aac4..d321cc554c1cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -70,7 +70,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." 
+ s" Expecting $n but got ${sample.size}.") - val localCurrMean= currMean + val localCurrMean = currMean val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index b3fad0c52d655..900007ec6bc74 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -80,6 +81,10 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) + /** Java-friendly version of [[corr()]] */ + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) + /** * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. @@ -96,6 +101,10 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) + /** Java-friendly version of [[corr()]] */ + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) + /** * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cd6add9d60b0d..cf51b24ff777f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. 
* (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]]) - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ @DeveloperApi class MultivariateGaussian ( - val mu: Vector, + val mu: Vector, val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - + private val breezeMu = mu.toBreeze.toDenseVector - + /** * private[mllib] constructor - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) } - + /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t - * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - + /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } - + /** Returns density of this multivariate Gaussian at given point, x */ private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 } - + /** * Calculate distribution dependent components used for the density function: * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. - * - * We here compute distribution-fixed parts + * + * We here compute distribution-fixed parts * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and * D^(-1/2)^ * U, where sigma = U * D * U.t - * + * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite, * we can use the eigendecomposition. We also do not compute the inverse directly; noting - * that - * + * that + * * sigma = U * D * U.t - * inv(Sigma) = U * inv(D) * U.t + * inv(Sigma) = U * inv(D) * U.t * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) - * + * * and thus - * + * * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ - * - * To guard against singular covariance matrices, this method computes both the + * + * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t - + // For numerical stability, values are considered to be non-zero only if they exceed tol. 
// This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length - + try { // log(pseudo-determinant) is sum of the logs of all non-zero singular values val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum - - // calculate the root-pseudo-inverse of the diagonal matrix of singular values + + // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index e597fce2babd1..23c8d7c7c8075 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -196,7 +196,7 @@ private[stat] object ChiSqTest extends Logging { * Pearson's independence test on the input contingency matrix. * TODO: optimize for SparseMatrix when it becomes supported. */ - def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = { val method = methodFromString(methodName) val numRows = counts.numRows val numCols = counts.numCols diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index dfe3a0b6913ef..cecd1fed896d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -169,7 +169,7 @@ object DecisionTree extends Serializable with Logging { numClasses: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).run(input) @@ -768,7 +768,7 @@ object DecisionTree extends Serializable with Logging { */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) val predict = calculatePredict(parentNodeAgg) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1f779584dcffd..a835f96d5d0e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,12 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) + case Regression => + GradientBoostedTrees.boost(input, input, 
boostingStrategy, validate = false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, - remappedInput, boostingStrategy, validate=false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -93,8 +93,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost( - input, validationInput, boostingStrategy, validate=true) + case Regression => + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -102,7 +102,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate=true) + validate = true) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging { logInfo(s"$timer") if (persistedInput) input.unpersist() - + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index b347c450c1aa8..069959976a188 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -249,7 +249,7 @@ private class RandomForest ( try { nodeIdCache.get.deleteAllCheckpoints() } catch { - case e:IOException => + case e: IOException => logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } @@ -474,7 +474,7 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). 
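The featureSubset assignment just below this comment draws numFeaturesPerNode feature indices with SamplingUtils.reservoirSampleAndCount. A generic, illustrative reservoir-sampling sketch of that idea (not the SamplingUtils implementation):

    import scala.util.Random

    // Keep the first k items, then replace a random slot with decreasing probability,
    // so every item ends up in the sample with probability k / n.
    def reservoirSample(items: Iterator[Int], k: Int, seed: Long): Array[Int] = {
      val rng = new Random(seed)
      val reservoir = new Array[Int](k)
      var i = 0
      items.foreach { item =>
        if (i < k) {
          reservoir(i) = item
        } else {
          val j = rng.nextInt(i + 1)
          if (j < k) reservoir(j) = item
        }
        i += 1
      }
      if (i < k) reservoir.take(i) else reservoir
    }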
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some(SamplingUtils.reservoirSampleAndCount(Range(0, + Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 431a839817eac..a6d1398fc267b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -83,7 +83,7 @@ class Node ( def predict(features: Vector) : Double = { if (isLeaf) { predict.predict - } else{ + } else { if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { leftNode.get.predict(features) @@ -151,9 +151,9 @@ class Node ( s"(feature ${split.feature} > ${split.threshold})" } case Categorical => if (left) { - s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})" } else { - s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})" } } } @@ -161,9 +161,9 @@ class Node ( if (isLeaf) { prefix + s"Predict: ${predict.predict}\n" } else { - prefix + s"If ${splitToString(split.get, left=true)}\n" + + prefix + s"If ${splitToString(split.get, left = true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + - prefix + s"Else ${splitToString(split.get, left=false)}\n" + + prefix + s"Else ${splitToString(split.get, left = false)}\n" + rightNode.get.subtreeToString(indentFactor + 1) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 0c5b4f9d04a74..bd73a866c8a82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -82,8 +82,7 @@ object MFDataGenerator { BLAS.gemm(z, A, B, 1.0, fullData) val df = rank * (m + n - rank) - val sampSize = scala.math.min(scala.math.round(trainSampFact * df), - scala.math.round(.99 * m * n)).toInt + val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt val rand = new Random() val mn = m * n val shuffled = rand.shuffle((0 until mn).toList) @@ -102,8 +101,8 @@ object MFDataGenerator { // optionally generate testing data if (test) { - val testSampSize = scala.math - .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt + val testSampSize = math.min( + math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 681f4c618d302..52d6468a72af7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -82,6 +82,18 @@ object MLUtils { val value = indexAndValue(1).toDouble (index, value) }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = 
indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, "indices should be one-based and in ascending order" ) + previous = current + i += 1 + } + (label, indices.toArray, values.toArray) } @@ -265,7 +277,7 @@ object MLUtils { } Vectors.fromBreeze(vector1) } - + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java new file mode 100644 index 0000000000000..d5bd230a957a1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaBucketizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void bucketizerTest() { + double[] splits = {-0.5, 0.0, 0.5}; + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[] { + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits); + + Row[] result = bucketizer.transform(dataset).select("result").collect(); + + for (Row r : result) { + double index = r.getDouble(0); + Assert.assertTrue((index >= 0) && (index <= 1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index da2218056307e..599e9cfd23ad4 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -55,9 +55,9 @@ public void tearDown() { @Test public void hashingTF() { JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java new file mode 100644 index 0000000000000..35b18c5308f61 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaStringIndexerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + sqlContext = null; + } + + @Test + public void testStringIndexer() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("label", StringType, false) + }); + JavaRDD rdd = jsc.parallelize( + Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + StringIndexer indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex"); + DataFrame output = indexer.fit(dataset).transform(dataset); + + Assert.assertArrayEquals( + new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + output.orderBy("id").select("id", "labelIndex").collect()); + } + + /** An alias for RowFactory.create. */ + private Row c(Object... values) { + return RowFactory.create(values); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java new file mode 100644 index 0000000000000..b7c564caad3bd --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaVectorAssemblerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testVectorAssembler() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("x", DoubleType, false), + createStructField("y", new VectorUDT(), false), + createStructField("name", StringType, false), + createStructField("z", new VectorUDT(), false), + createStructField("n", LongType, false) + }); + Row row = RowFactory.create( + 0, 0.0, Vectors.dense(1.0, 2.0), "a", + Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"x", "y", "z", "n"}) + .setOutputCol("features"); + DataFrame output = assembler.transform(dataset); + Assert.assertEquals( + Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + output.select("features").first().getAs(0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index e7df10dfa63ac..9890155e9f865 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -50,6 +50,7 @@ public void testParams() { testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyStringParam(), "a"); + Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 947ae3a2ce06f..ff5929235ac2c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -51,7 +51,8 @@ public String uid() { public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } public JavaTestParams setMyIntParam(int value) { - set(myIntParam_, value); return this; + set(myIntParam_, value); + return this; } private DoubleParam myDoubleParam_; @@ -60,7 +61,8 @@ public JavaTestParams setMyIntParam(int value) { public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } public JavaTestParams setMyDoubleParam(double value) { - set(myDoubleParam_, value); return this; + set(myDoubleParam_, 
value); + return this; } private Param myStringParam_; @@ -69,7 +71,18 @@ public JavaTestParams setMyDoubleParam(double value) { public String getMyStringParam() { return getOrDefault(myStringParam_); } public JavaTestParams setMyStringParam(String value) { - set(myStringParam_, value); return this; + set(myStringParam_, value); + return this; + } + + private DoubleArrayParam myDoubleArrayParam_; + public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } + + public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } + + public JavaTestParams setMyDoubleArrayParam(double[] value) { + set(myDoubleArrayParam_, value); + return this; } private void init() { @@ -79,8 +92,14 @@ private void init() { List validStrings = Lists.newArrayList("a", "b"); myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); - setDefault(myIntParam_, 1); - setDefault(myDoubleParam_, 0.5); + myDoubleArrayParam_ = + new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); + + setDefault(myIntParam(), 1); + setDefault(myIntParam().w(1)); + setDefault(myDoubleParam(), 0.5); setDefault(myIntParam().w(1), myDoubleParam().w(0.5)); + setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); + setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala index 67c262d0f9d8d..928301523fba9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class IdentifiableSuite extends FunSuite { +class IdentifiableSuite extends SparkFunSuite { import IdentifiableSuite.Test diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java similarity index 95% rename from mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java rename to mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 640d2ec55e4e7..55787f8606d48 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
 */ -package org.apache.spark.ml.classification; +package org.apache.spark.mllib.classification; import java.io.Serializable; import java.util.List; @@ -28,7 +28,6 @@ import org.junit.Test; import org.apache.spark.SparkConf; -import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java new file mode 100644 index 0000000000000..467a7a69e8f30 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaGaussianMixtureSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGaussianMixture"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runGaussianMixture() { + List<Vector> points = Lists.newArrayList( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0) + ); + + JavaRDD<Vector> data = sc.parallelize(points, 2); + GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) + .run(data); + assertEquals(model.gaussians().length, 2); + JavaRDD<Integer> predictions = model.predict(data); + predictions.first(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 96c2da169961f..581c033f08ebe 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -107,6 +107,10 @@ public void distributedLDAModel() { // Check: log probabilities assert(model.logLikelihood() < 0.0); assert(model.logPrior() < 0.0); + + // Check: topic distributions + JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions(); + assertEquals(topicDistributions.count(), corpus.count()); } @Test diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java new file mode 100644 index 0000000000000..3b0e879eec77f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.apache.spark.streaming.JavaTestUtils.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +public class JavaStreamingKMeansSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List<Vector> trainingBatch = Lists.newArrayList( + Vectors.dense(1.0), + Vectors.dense(0.0)); + JavaDStream<Vector> training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList( + new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)), + new Tuple2<Integer, Vector>(11, Vectors.dense(0.0))); + JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingKMeans skmeans = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0}); + skmeans.trainOn(training); + JavaPairDStream<Integer, Integer> prediction = skmeans.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java new file mode 100644 index 0000000000000..62f7f26b7c98f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaStatisticsSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaStatistics"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testCorr() { + JavaRDD<Double> x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); + JavaRDD<Double> y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); + + Double corr1 = Statistics.corr(x, y); + Double corr2 = Statistics.corr(x, y, "pearson"); + // Check default method + assertEquals(corr1, corr2); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2b04a3034782e..29394fefcbc43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.ml +import scala.collection.JavaConverters._ + import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame -class PipelineSuite extends FunSuite { +class PipelineSuite extends SparkFunSuite { abstract class MyModel extends Model[MyModel] @@ -81,4 +83,19 @@ class PipelineSuite extends FunSuite { pipeline.fit(dataset) } } + + test("pipeline model constructors") { + val transform0 = mock[Transformer] + val model1 = mock[MyModel] + + val stages = Array(transform0, model1) + val pipelineModel0 = new PipelineModel("pipeline0", stages) + assert(pipelineModel0.uid === "pipeline0") + assert(pipelineModel0.stages === stages) + + val stagesAsList = stages.toList.asJava + val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList) + assert(pipelineModel1.uid === "pipeline1") + assert(pipelineModel1.stages === stages) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 17ddd335deb6d..512cffb1acb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite +import 
org.apache.spark.SparkFunSuite -class AttributeGroupSuite extends FunSuite { +class AttributeGroupSuite extends SparkFunSuite { test("attribute group") { val attrs = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index ec9b717e41ce8..72b575d022547 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class AttributeSuite extends FunSuite { +class AttributeSuite extends SparkFunSuite { test("default numeric attribute") { val attr: NumericAttribute = NumericAttribute.defaultAttr diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 3fdc66be8a314..ae40b0b8ff854 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -251,7 +250,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { */ } -private[ml] object DecisionTreeClassifierSuite extends FunSuite { +private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -266,7 +265,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. 
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index ea86867f1161a..1302da3c373ff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTClassifierSuite.compareAPIs @@ -128,7 +127,7 @@ private object GBTClassifierSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9f77d5f3efc55..a755cac3ea76e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 770b56890fa45..1d04ccb509057 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -30,7 +29,7 @@ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD 
import org.apache.spark.sql.DataFrame -class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var rdd: RDD[LabeledPoint] = _ @@ -94,6 +93,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) ova.fit(datasetWithLabelMetadata) } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index cdbbacab8e0e3..eee9355a67be3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -32,7 +31,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestClassifierSuite.compareAPIs @@ -158,7 +157,7 @@ private object RandomForestClassifierSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. 
val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 3ea7aad5274f2..36a1ac6b7996d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.ml.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression Evaluator: default params") { /** @@ -39,7 +38,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext { val dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) - + /** * Using the following R code to load the data, train the model and evaluate metrics. * diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 8f6c6b39dc93b..7953bd0417191 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends FunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Double] = _ @@ -48,7 +47,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext { test("Binarize continuous features with setter") { val threshold: Double = 0.2 - val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) val dataFrame: DataFrame = sqlContext.createDataFrame( data.zip(thresholdBinarized)).toDF("feature", "expected") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 0391bd8427c2c..507a8a7db24c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.ml.feature import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class BucketizerSuite extends FunSuite with MLlibTestSparkContext { +class BucketizerSuite 
extends SparkFunSuite with MLlibTestSparkContext { test("Bucket continuous features, without -inf,inf") { // Check a set of valid feature values. @@ -110,7 +108,7 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext { } } -private object BucketizerSuite extends FunSuite { +private object BucketizerSuite extends SparkFunSuite { /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */ def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = { require(feature >= splits.head) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 2e4beb0bfff63..7b2d70e644005 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -26,7 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { val hashingTF = new HashingTF diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index f85e85471617a..d83772e8be755 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9d09f24709e23..9f03470b7f328 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 
056b9eda86bba..2e5036a844562 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col - -class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { +class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -36,15 +36,16 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { indexer.transform(df) } - test("OneHotEncoder includeFirst = true") { + test("OneHotEncoder dropLast = false") { val transformed = stringIndexed() val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") + .setDropLast(false) val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -53,22 +54,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { assert(output === expected) } - test("OneHotEncoder includeFirst = false") { + test("OneHotEncoder dropLast = true") { val transformed = stringIndexed() val encoder = new OneHotEncoder() - .setIncludeFirst(false) .setInputCol("labelIndex") .setOutputCol("labelVec") val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0), - (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)) + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) assert(output === expected) } + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCol("size") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val encoder = new OneHotEncoder() + .setInputCol("index") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) + assert(group.getAttr(1) === 
BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index aa230ca073d5b..feca866cd711d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite import org.scalatest.exceptions.TestFailedException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { +class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { test("Polynomial expansion with default parameter") { val data = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 89c2fe45573aa..5f557e16e5150 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.mllib.util.MLlibTestSparkContext -class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { +class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -61,4 +60,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexerModel should keep silent if the input column does not exist.") { + val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) + .setInputCol("label") + .setOutputCol("labelIndex") + val df = sqlContext.range(0L, 10L) + assert(indexerModel.transform(df).eq(df)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index eabda089d0988..ac279cb3215c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ test("RegexTokenizer") { @@ -60,7 +59,7 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { } } -object RegexTokenizerSuite extends FunSuite { +object 
RegexTokenizerSuite extends SparkFunSuite { def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { t.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index d0cd62c5e4864..489abb5af7130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col -class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { +class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble @@ -61,4 +61,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) } } + + test("ML attributes") { + val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") + val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) + val user = new AttributeGroup("user", Array( + NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), + NumericAttribute.defaultAttr.withName("salary"))) + val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) + val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + .select( + col("browser").as("browser", browser.toMetadata()), + col("hour").as("hour", hour.toMetadata()), + col("count"), // "count" is an integer column without ML attribute + col("user").as("user", user.toMetadata()), + col("ad")) // "ad" is a vector column without ML attribute + val assembler = new VectorAssembler() + .setInputCols(Array("browser", "hour", "count", "user", "ad")) + .setOutputCol("features") + val output = assembler.transform(df) + val schema = output.schema + val features = AttributeGroup.fromStructField(schema("features")) + assert(features.size === 7) + val browserOut = features.getAttr(0) + assert(browserOut === browser.withIndex(0).withName("browser")) + val hourOut = features.getAttr(1) + assert(hourOut === hour.withIndex(1).withName("hour")) + val countOut = features.getAttr(2) + assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2)) + val userGenderOut = features.getAttr(3) + assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) + val userSalaryOut = features.getAttr(4) + assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index b11b029c6343e..06affc7305cf5 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { import VectorIndexerSuite.FeatureData diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 43a09cc418703..94ebc3aebfa37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { test("Word2Vec") { val sqlContext = new SQLContext(sc) @@ -35,9 +34,9 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val codes = Map( - "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), - "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), - "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) ) val expected = doc.map { sentence => diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 1505ad872536b..778abcba22c10 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.impl import scala.collection.JavaConverters._ -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.tree._ @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} -private[ml] object TreeTests extends FunSuite { +private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index d270ad7613af1..96094d7a099aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.param -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ParamsSuite extends FunSuite { +class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() @@ -27,7 +27,7 @@ class ParamsSuite extends FunSuite { import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") - assert(maxIter.doc === "max number of iterations (>= 0)") + assert(maxIter.doc === "maximum number of iterations (>= 0)") assert(maxIter.parent === uid) assert(maxIter.toString === s"${uid}__maxIter") assert(!maxIter.isValid(-1)) @@ -36,7 +36,7 @@ class ParamsSuite extends FunSuite { solver.setMaxIter(5) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 5)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)") assert(inputCol.toString === s"${uid}__inputCol") @@ -120,7 +120,7 @@ class ParamsSuite extends FunSuite { intercept[NoSuchElementException](solver.getInputCol) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 100)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) @@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validateParams() } - solver.validateParams(ParamMap(inputCol -> "input")) + solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) @@ -202,7 +202,7 @@ class ParamsSuite extends FunSuite { } } -object ParamsSuite extends FunSuite { +object ParamsSuite extends SparkFunSuite { /** * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala index ca18fa1ad3c15..eb5408d3fee7c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.param.shared -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.Params -class SharedParamsSuite extends FunSuite { +class SharedParamsSuite extends SparkFunSuite { test("outputCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9a35555e52b90..2e5cfe7027eb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.scalatest.FunSuite -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import 
org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -36,7 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.Utils -class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { private var tempDir: File = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 1196a772dfdd4..33aa9d0d62343 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeRegressorSuite.compareAPIs @@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // TODO: test("model save/load") SPARK-6725 } -private[ml] object DecisionTreeRegressorSuite extends FunSuite { +private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -83,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 40e7e3273e965..98fb3d3f5f22c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTRegressor]]. 
*/ -class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTRegressorSuite.compareAPIs @@ -129,7 +128,7 @@ private object GBTRegressorSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 50a78631fa6d6..732e2c42be144 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 3efffbb763b78..b24ecaa57c89b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestRegressorSuite.compareAPIs @@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { */ } -private object RandomForestRegressorSuite extends FunSuite { +private object RandomForestRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -114,7 +113,7 @@ private object RandomForestRegressorSuite extends FunSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. 
val oldModelAsNew = RandomForestRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 05313d440fbf6..9b3619f0046ea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,15 +17,19 @@ package org.apache.spark.ml.tuning -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { +class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @@ -52,5 +56,56 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) + } + + test("validateParams should check estimatorParamMaps") { + import CrossValidatorSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new CrossValidator() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + + cv.validateParams() // This should pass. 
+ + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object CrossValidatorSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = { + throw new UnsupportedOperationException + } + + override def transformSchema(schema: StructType): StructType = { + throw new UnsupportedOperationException + } + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = { + throw new UnsupportedOperationException + } + + override val uid: String = "eval" } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala index 20aa100112bfe..810b70049ec15 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.{ParamMap, TestParams} -class ParamGridBuilderSuite extends FunSuite { +class ParamGridBuilderSuite extends SparkFunSuite { val solver = new TestParams() import solver.{inputCol, maxIter} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index a629dba8a426f..59944416d96a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.api.python -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating -class PythonMLLibAPISuite extends FunSuite { +class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() @@ -84,7 +83,7 @@ class PythonMLLibAPISuite extends FunSuite { val smt = new SparseMatrix( 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), - isTransposed=true) + isTransposed = true) val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix] assert(smt.toArray === nsmt.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 966811a5a3263..e8f3d0c4db20a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.util.Random import scala.util.control.Breaks._ -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, 
MLlibTestSparkContext} @@ -119,7 +119,7 @@ object LogisticRegressionSuite { } // Preventing the overflow when we compute the probability val maxMargin = margins.max - if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin // Computing the probabilities for each class from the margins. val norm = { @@ -130,7 +130,7 @@ object LogisticRegressionSuite { } temp } - for (i <-0 until nClasses) probs(i) /= norm + for (i <- 0 until nClasses) probs(i) /= norm // Compute the cumulative probability so we can generate a random number and assign a label. for (i <- 1 until nClasses) probs(i) += probs(i - 1) @@ -169,7 +169,7 @@ object LogisticRegressionSuite { } -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -541,7 +541,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } -class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index c111a78a55806..f7fc8730606af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,9 +21,8 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -86,7 +85,7 @@ object NaiveBayesSuite { pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial) } -class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { import NaiveBayes.{Multinomial, Bernoulli} @@ -163,7 +162,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { val theta = Array( Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 - Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( @@ -286,7 +285,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } -class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { +class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 10 diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 6de098b383ba3..b1d78cba9e3dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -21,9 +21,8 @@ import scala.collection.JavaConversions._ import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -46,7 +45,7 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => @@ -62,7 +61,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with MLlibTestSparkContext { +class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -91,7 +90,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -117,7 +116,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -127,8 +126,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -145,7 +144,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val initialB = -1.0 val initialC = -1.0 @@ -159,8 +158,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD, initialWeights) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. 
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -177,7 +176,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) val testRDDInvalid = testRDD.map { lp => @@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { } } -class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { +class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 5683b55e8500a..e98b61e13e21f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index f356ffa3e3a26..b218d72f1268a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), @@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } } - + test("two clusters") { val data = sc.parallelize(GaussianTestData.data) @@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0f2b26d462ad2..0dbbd7127444f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class KMeansSuite extends FunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} @@ -75,7 +74,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { val center = Vectors.dense(1.0, 2.0, 3.0) // Make sure code runs. - var model = KMeans.train(data, k=2, maxIterations=1) + var model = KMeans.train(data, k = 2, maxIterations = 1) assert(model.clusterCenters.size === 2) } @@ -87,7 +86,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { 2) // Make sure code runs. - var model = KMeans.train(data, k=3, maxIterations=1) + var model = KMeans.train(data, k = 3, maxIterations = 1) assert(model.clusterCenters.size === 3) } @@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { } } -object KMeansSuite extends FunSuite { +object KMeansSuite extends SparkFunSuite { def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { val singlePoint = isSparse match { case true => @@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite { } } -class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { +class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d5b7d96335744..406affa25539d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class LDASuite extends FunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6d6fe6fe46bab..19e65f1b53ab5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import 
org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { +class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.PowerIterationClustering._ @@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) - + val model2 = new PowerIterationClustering() .setK(2) .setInitializationMode("degree") @@ -94,11 +92,13 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext */ val similarities = Seq[(Long, Long, Double)]( (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + // scalastyle:off val expected = Array( Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + // scalastyle:on val w = normalize(sc.parallelize(similarities, 2)) w.edges.collect().foreach { case Edge(i, j, x) => assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) @@ -128,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext } } -object PowerIterationClusteringSuite extends FunSuite { +object PowerIterationClusteringSuite extends SparkFunSuite { def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { val assignments = sc.parallelize( (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index f90025d535e45..ac01622b8a089 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends FunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 @@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { assert(math.abs(c1) ~== 0.8 absTol 0.6) } + test("SPARK-7946 setDecayFactor") { + val kMeans = new StreamingKMeans() + assert(kMeans.decayFactor === 1.0) + kMeans.setDecayFactor(2.0) + assert(kMeans.decayFactor === 2.0) + } + def StreamingKMeansDataGenerator( numPoints: Int, numBatches: Int, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 79847633ff0dc..87ccc7eda44ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -17,12 +17,11 @@ package 
org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { +class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index e0224f960cc43..99d52fabc5309 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { +class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 7dc4f3cfbc4e4..d55bc8c3ec09f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 2537dd62c92f2..f3b19aeb42f84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 609eed983ff4e..c0924a213a844 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 670b4c34e6095..9de2bdb6d7246 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( - Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, "explained variance regression score mismatch") @@ -39,7 +38,7 @@ class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( - Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, "explained variance regression score mismatch") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 747f5914598ec..889727fb55823 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { /* * Contingency tables diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala index f3a482abda873..ccbf8a91cdd37 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite 
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext { +class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { test("elementwise (hadamard) product should properly apply vector to dense data set") { val denseData = Array( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index 0c4dfb7b97c7f..cf279c02334e9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 0a5cad7caf8e4..21163633051e5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 5c4af2b99e68b..34122d6ed2e95 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - import breeze.linalg.{norm => brzNorm} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index 758af588f1c69..e57f49191378f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import 
org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext -class PCASuite extends FunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { private val data = Array( Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 7f94564b2a3ae..6ab2fa6770123 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { // When the input data is all constant, the variance is zero. The standardization against // zero variance is not well-defined, but we decide to just set it into zero here. @@ -360,7 +359,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } withClue("model needs std and mean vectors to be equal size when both are provided") { intercept[IllegalArgumentException] { - val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 98a98a7599bcb..b6818369208d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index bd5b9cc3afa10..66ae3543ecc4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.mllib.fpm -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { test("FP-Growth using String type") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala index 04017f67c311d..a56d7b3579213 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm import scala.language.existentials -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPTreeSuite extends FunSuite with MLlibTestSparkContext { +class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") { val tree = new FPTree[String] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index 699f009f0f2ec..d34888af2d73b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.mllib.impl -import org.scalatest.FunSuite - import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { import PeriodicGraphCheckpointerSuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 64ecd12ea7ded..b0f3f71113c57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ -class BLASSuite extends FunSuite { +class BLASSuite extends SparkFunSuite { test("copy") { val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) @@ -140,7 +139,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dA) assert(dA ~== expected absTol 1e-15) - + val dB = new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0)) @@ -149,7 +148,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dB) } } - + val dC = new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8)) @@ -158,7 +157,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dC) } } - + val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5)) withClue("Size of vector must match the rank of matrix") { @@ -256,13 +255,13 @@ class BLASSuite extends FunSuite { val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) - + val dA2 = new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true) val sA2 = new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0), true) - + val dx = new DenseVector(Array(1.0, 2.0, 3.0)) val sx = dx.toSparse val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) @@ -271,7 +270,7 @@ class BLASSuite extends FunSuite { assert(sA.multiply(dx) ~== 
expected absTol 1e-15) assert(dA.multiply(sx) ~== expected absTol 1e-15) assert(sA.multiply(sx) ~== expected absTol 1e-15) - + val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy @@ -288,7 +287,7 @@ class BLASSuite extends FunSuite { val y14 = y1.copy val y15 = y1.copy val y16 = y1.copy - + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -296,42 +295,42 @@ class BLASSuite extends FunSuite { gemv(1.0, sA, dx, 2.0, y2) gemv(1.0, dA, sx, 2.0, y3) gemv(1.0, sA, sx, 2.0, y4) - + gemv(1.0, dA2, dx, 2.0, y5) gemv(1.0, sA2, dx, 2.0, y6) gemv(1.0, dA2, sx, 2.0, y7) gemv(1.0, sA2, sx, 2.0, y8) - + gemv(2.0, dA, dx, 2.0, y9) gemv(2.0, sA, dx, 2.0, y10) gemv(2.0, dA, sx, 2.0, y11) gemv(2.0, sA, sx, 2.0, y12) - + gemv(2.0, dA2, dx, 2.0, y13) gemv(2.0, sA2, dx, 2.0, y14) gemv(2.0, dA2, sx, 2.0, y15) gemv(2.0, sA2, sx, 2.0, y16) - + assert(y1 ~== expected2 absTol 1e-15) assert(y2 ~== expected2 absTol 1e-15) assert(y3 ~== expected2 absTol 1e-15) assert(y4 ~== expected2 absTol 1e-15) - + assert(y5 ~== expected2 absTol 1e-15) assert(y6 ~== expected2 absTol 1e-15) assert(y7 ~== expected2 absTol 1e-15) assert(y8 ~== expected2 absTol 1e-15) - + assert(y9 ~== expected3 absTol 1e-15) assert(y10 ~== expected3 absTol 1e-15) assert(y11 ~== expected3 absTol 1e-15) assert(y12 ~== expected3 absTol 1e-15) - + assert(y13 ~== expected3 absTol 1e-15) assert(y14 ~== expected3 absTol 1e-15) assert(y15 ~== expected3 absTol 1e-15) assert(y16 ~== expected3 absTol 1e-15) - + withClue("columns of A don't match the rows of B") { intercept[Exception] { gemv(1.0, dA.transpose, dx, 2.0, y1) @@ -346,12 +345,12 @@ class BLASSuite extends FunSuite { gemv(1.0, sA.transpose, sx, 2.0, y1) } } - + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - + val dATT = dAT.transpose val sATT = sAT.transpose diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 2031032373971..dc04258e41d27 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} -class BreezeMatrixConversionSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index 8abdac72902c6..3772c9235ad3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.SparkFunSuite + /** * Test Breeze vector conversions. 
*/ -class BreezeVectorConversionSuite extends FunSuite { +class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 86119ec38101e..8dbb70f5d1c4c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg import java.util.Random import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class MatricesSuite extends FunSuite { +class MatricesSuite extends SparkFunSuite { test("dense matrix construction") { val m = 3 val n = 2 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 24755e9ff46fc..c4ae0a16f7c04 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends FunSuite { +class VectorsSuite extends SparkFunSuite { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 @@ -215,13 +214,13 @@ class VectorsSuite extends FunSuite { val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) - // SparseVector vs. SparseVector - assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) + // SparseVector vs. SparseVector + assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. SparseVector assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. 
DenseVector assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8) - } + } } test("foreachActive") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 949d1c9939570..93fe04c139b9a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { +class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 @@ -57,11 +56,13 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { val random = new ju.Random() // This should generate a 4x4 grid of 1x2 blocks. val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + // scalastyle:off val expected0 = Array( Array(0, 0, 4, 4, 8, 8, 12), Array(1, 1, 5, 5, 9, 9, 13), Array(2, 2, 6, 6, 10, 10, 14), Array(3, 3, 7, 7, 11, 11, 15)) + // scalastyle:on for (i <- 0 until 4; j <- 0 until 7) { assert(part0.getPartition((i, j)) === expected0(i)(j)) assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 04b36a9ef9990..f3728cd036a3f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { +class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 2ab53cc13db71..4a7b99a976f0a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class IndexedRowMatrixSuite extends 
SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 27bb19f472e1e..b6cb53d0c743e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } -class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext { +class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { var mat: RowMatrix = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 86481c6e66200..a5a59e9fad5ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.mllib.optimization import scala.collection.JavaConversions._ import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -42,7 +43,7 @@ object GradientDescentSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -61,7 +62,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 @@ -140,7 +141,7 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc } } -class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext { +class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index c8f2adcf155a7..d07b9d5b89227 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import 
org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 @@ -229,7 +230,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { } } -class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { +class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index 22855e4e8f247..d8f9b8c33963d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.FunSuite - import org.jblas.{DoubleMatrix, SimpleBlas} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class NNLSSuite extends FunSuite { +class NNLSSuite extends SparkFunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) @@ -68,12 +67,14 @@ class NNLSSuite extends FunSuite { test("NNLS: nonnegativity constraint active") { val n = 5 + // scalastyle:off val ata = new DoubleMatrix(Array( Array( 4.377, -3.531, -1.306, -0.139, 3.418), Array(-3.531, 4.344, 0.934, 0.305, -2.140), Array(-1.306, 0.934, 2.644, -0.203, -0.170), Array(-0.139, 0.305, -0.203, 5.883, 1.428), Array( 3.418, -2.140, -0.170, 1.428, 4.684))) + // scalastyle:on val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala index 0b646cf1ce6c4..4c6e76e47419b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.dmg.pmml.RegressionNormalizationMethodType -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.util.LinearDataGenerator -class BinaryClassificationPMMLModelExportSuite extends FunSuite { +class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { test("logistic regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) @@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure logistic regression has normalization method set to LOGIT 
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) } - + test("linear SVM PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - + // assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) val pmml = svmModelExport.getPmml @@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure linear SVM has normalization method set to NONE assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index f9afbd888dfc5..1d32309481787 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class GeneralizedLinearPMMLModelExportSuite extends FunSuite { +class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite { test("linear regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index b985d0446d7b0..b3f9750afa730 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors -class KMeansPMMLModelExportSuite extends FunSuite { +class KMeansPMMLModelExportSuite extends SparkFunSuite { test("KMeansPMMLModelExport generate PMML format") { val clusterCenters = Array( @@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite { val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index f28a4ac8ad01f..af49450961750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.pmml.export -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import 
org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class PMMLModelExportFactorySuite extends FunSuite { +class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { val clusterCenters = Array( @@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite { test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + "when passing a LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) - + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) } - + test("PMMLModelExportFactory throw IllegalArgumentException " + "when passing a Multinomial Logistic Regression") { /** 3 classes, 2 features */ val multiclassLogisticRegressionModel = new LogisticRegressionModel( - weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) - + intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index b792d819fdabb..a5ca1518f82f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.random import scala.math -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class RandomDataGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends SparkFunSuite { def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 63f2ea916d457..413db2000d6d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} @@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter * * TODO update 
tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { +class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 57216e8eb4a55..10f5a2be48f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ -class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5)), 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 6d6c0aa5be812..bc64172614830 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index b3798940ddc38..05b87728d6fdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.math.abs import scala.util.Random -import org.scalatest.FunSuite import org.jblas.DoubleMatrix +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel @@ -84,7 +84,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with MLlibTestSparkContext { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c92866f3893d..2c8ed057a516a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.recommendation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext 
import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { +class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext { val rank = 2 var userFeatures: RDD[(Int, Array[Double])] = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 3b38bdf5ef5eb..ea4f2865757c1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.regression -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { private def round(d: Double) = { math.round(d * 100).toDouble / 100 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index 110c44a7193fd..d8364a06de4da 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -class LabeledPointSuite extends FunSuite { +class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index c9f5dc069ef2e..08a152ffc7a23 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LassoSuite { val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LassoSuite extends FunSuite with MLlibTestSparkContext { +class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -67,11 +66,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, 
Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -110,11 +110,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData,2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -141,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { } } -class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { +class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 3781931c2f819..f88a1c33c9f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LinearRegressionSuite { val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index d6c93cc0e49cd..7a781fee634c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite 
+import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -33,7 +33,7 @@ private object RidgeRegressionSuite { val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { +class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => @@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 26604dbe6c1ef..9a379406d5061 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index d20a09b4b4925..c292ced75e870 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) @@ -96,11 +95,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { val X = sc.parallelize(data) val defaultMat = Statistics.corr(X) val pearsonMat = Statistics.corr(X, "pearson") + // scalastyle:off val expected = BDM( (1.00000000, 0.05564149, Double.NaN, 0.4004714), (0.05564149, 1.00000000, Double.NaN, 0.9135959), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), - (0.40047142, 0.91359586, Double.NaN,1.0000000)) + (0.40047142, 0.91359586, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(defaultMat.toBreeze, 
expected)) assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) } @@ -108,11 +109,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { test("corr(X) spearman") { val X = sc.parallelize(data) val spearmanMat = Statistics.corr(X, "spearman") + // scalastyle:off val expected = BDM( (1.0000000, 0.1054093, Double.NaN, 0.4000000), (0.1054093, 1.0000000, Double.NaN, 0.9486833), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), (0.4000000, 0.9486833, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 15418e6035965..b084a5fb4313f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat import java.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { +class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala index 14bb1cebf0b8f..5feccdf33681a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -18,19 +18,19 @@ package org.apache.spark.mllib.stat import org.apache.commons.math3.distribution.NormalDistribution -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { +class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { val rdd = sc.parallelize(Array(5.0)) val evaluationPoints = Array(5.0, 6.0) val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal = new NormalDistribution(5.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - normal.density(5.0) < acceptableErr) - assert(densities(0) - normal.density(6.0) < acceptableErr) + assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr) + assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr) } test("kernel density multiple samples") { @@ -40,7 +40,9 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { val normal1 = new NormalDistribution(5.0, 3.0) val normal2 = new NormalDistribution(10.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr) - assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr) + assert(math.abs( + densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr) + assert(math.abs( + densities(1) - (normal1.density(6.0) + 
normal2.density(6.0)) / 2) < acceptableErr) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 23b0eec865de6..07efde4f5e6dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateOnlineSummarizerSuite extends FunSuite { +class MultivariateOnlineSummarizerSuite extends SparkFunSuite { test("basic error handing") { val summarizer = new MultivariateOnlineSummarizer diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index fac2498e4dcb3..aa60deb665aeb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,49 +17,48 @@ package org.apache.spark.mllib.stat.distribution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { +class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) - + val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } - + test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } - + test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index ce983eb27fa35..356d957f15909 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ 
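Editor note on the KernelDensitySuite hunk above: it fixes two latent weaknesses in the assertions, indexing densities(1) for the second evaluation point instead of re-testing densities(0), and wrapping both comparisons in math.abs. A standalone illustration of the second point, written in plain Python rather than the suite's Scala and using made-up values, shows why a signed-difference check can never fail when the estimate undershoots:

    # Plain-Python illustration; not Spark code, values are hypothetical.
    def close_enough(actual, expected, tol=1e-6):
        # Absolute error catches deviations in both directions.
        return abs(actual - expected) < tol

    estimate, expected = 0.12938758, 0.12938760
    assert close_enough(estimate, expected)          # within tolerance
    assert not close_enough(estimate, expected + 1)  # correctly rejected
    # The old signed form "passes" even for a wildly low estimate:
    assert (estimate - (expected + 1)) < 1e-6        # True, but meaningless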
import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils -class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// // Tests examining individual elements of training @@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } } -object DecisionTreeSuite extends FunSuite { +object DecisionTreeSuite extends SparkFunSuite { def validateClassifier( model: DecisionTreeModel, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 55b0bac7d49fe..84dd3b342d4c0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GradientBoostedTrees]]. */ -class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 92b498580af03..49aff21fe7914 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. 
*/ -class ImpuritySuite extends FunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 4ed66953cb628..e6df5d974bf36 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -35,7 +34,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[RandomForest]]. */ -class RandomForestSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index b184e936672ca..9d756da410325 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree.impl -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[BaggedPoint]]. 
*/ -class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 668fc1d43c5d6..70219e9ad9d3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -21,19 +21,19 @@ import java.io.File import scala.io.Source -import org.scalatest.FunSuite - import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { +class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") @@ -63,7 +63,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") - if (m > 10) { + if (m > 10) { val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) val norm4 = Vectors.norm(v4, 2.0) @@ -109,6 +109,40 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") { + val lines = + """ + |0 + |0 0:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + + test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") { + val lines = + """ + |0 + |0 3:4.0 2:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + test("saveAsLibSVMFile") { val examples = sc.parallelize(Seq( LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), @@ -168,7 +202,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { "Each training+validation set combined should contain all of the data.") } // K fold cross validation should only have each element in the validation set exactly once - assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === + assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted === data.collect().sorted) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index f68fb95eac4e4..8dcb9ba9be108 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.util -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.SparkException - -class NumericParserSuite extends FunSuite { +class NumericParserSuite extends SparkFunSuite { test("parser") { val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 59e6c778806f4..8f475f30249d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.util +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.scalatest.exceptions.TestFailedException -class TestingUtilsSuite extends FunSuite { +class TestingUtilsSuite extends SparkFunSuite { test("Comparing doubles using relative error.") { diff --git a/network/common/pom.xml b/network/common/pom.xml index 0c3147761cfc5..a85e0a66f4a30 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 7dc7c65825e34..4b5bfcb6f04bc 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 1e2e9c80af6cc..a99f7c4392d3d 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index c72d7cbf843ef..e9700a5d7b149 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -114,11 +114,10 @@ UTF-8 UTF-8 - org.spark-project.akka - 2.3.4-spark - 1.6 + com.typesafe.akka + 2.3.11 + 1.7 spark - 2.0.1 0.21.1 shaded-protobuf 1.7.10 @@ -137,7 +136,7 @@ 0.13.1 10.10.1.1 - 1.6.0rc3 + 1.7.0 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 @@ -180,7 +179,7 @@ compile ${session.executionRootDirectory} @@ -269,6 +268,18 @@ false + + + spark-1.4-staging + Spark 1.4 RC4 Staging Repository + https://repository.apache.org/content/repositories/orgapachespark-1112 + + true + + + false + + @@ -576,7 +587,7 @@ io.netty netty-all - 4.0.23.Final + 4.0.28.Final org.apache.derby @@ -1069,13 +1080,13 @@ - com.twitter + org.apache.parquet parquet-column ${parquet.version} ${parquet.deps.scope} - com.twitter + org.apache.parquet parquet-hadoop ${parquet.version} ${parquet.deps.scope} @@ -1205,15 +1216,6 @@ -target ${java.version} - - - - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} - - @@ -1252,7 +1254,9 @@ ${test.java.home} + test true + ${project.build.directory}/tmp ${spark.test.home} 1 false @@ -1284,7 +1288,9 @@ ${test.java.home} + test true + ${project.build.directory}/tmp 
${spark.test.home} 1 false @@ -1426,6 +1432,8 @@ 2.3 false + + false @@ -1542,6 +1550,26 @@ + + + org.apache.maven.plugins + maven-antrun-plugin + + + create-tmp-dir + generate-test-resources + + run + + + + + + + + + + org.apache.maven.plugins @@ -1664,6 +1692,8 @@ 0.98.7-hadoop1 hadoop1 1.8.8 + org.spark-project.akka + 2.3.4-spark @@ -1753,22 +1783,6 @@ sql/hive-thriftserver - - hive-0.12.0 - - 0.12.0-protobuf-2.5 - 0.12.0 - 10.4.2.0 - - - - hive-0.13.1 - - 0.13.1a - 0.13.1 - 10.10.1.1 - - scala-2.10 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index dde92949fa175..5812b72f0aa78 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,8 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.3.0" + // TODO: Change this once Spark 1.4.0 is released + val previousSparkVersion = "1.4.0-rc4" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 11b439e7875fc..8a93ca2999510 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,10 +34,31 @@ import com.typesafe.tools.mima.core.ProblemFilters._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.5") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // JavaRDDLike is not meant to be extended by user programs + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Mima false positive (was a private[spark] class) + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PairIterator"), + // Removing a testing method from a private class + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution") + ) case v if v.startsWith("1.4") => Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), // These are needed if checking against the sbt build, since they are part of diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b9515a12bc573..d7e374558c5e2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConversions._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ -import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -52,6 +51,11 @@ object BuildCommons { // Root project. 
val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation + + val testTempDir = s"$sparkHome/target/tmp" + if (!new File(testTempDir).isDirectory()) { + require(new File(testTempDir).mkdirs()) + } } object SparkBuild extends PomBuild { @@ -118,7 +122,12 @@ object SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( + lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq( + libraryDependencies += compilerPlugin( + "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), + scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) + + lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( javaHome := sys.env.get("JAVA_HOME") .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), @@ -126,7 +135,7 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - unidocGenjavadocVersion := "0.8", + unidocGenjavadocVersion := "0.9-spark0", resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), @@ -174,9 +183,6 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) - /* Catalyst macro settings */ - enable(Catalyst.settings)(catalyst) - /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -271,14 +277,6 @@ object OldDeps { ) } -object Catalyst { - lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), - // Quasiquotes break compiling scala doc... - // TODO: Investigate fixing this. 
- sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) -} - object SQL { lazy val settings = Seq( initialCommands in console := @@ -503,6 +501,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), + javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", @@ -511,6 +510,7 @@ object TestSettings { javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", + javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", diff --git a/project/plugins.sbt b/project/plugins.sbt index 7096b0d3ee7de..75bd604a1b857 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -25,7 +25,7 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 0d21a132048a5..adca90ddaf397 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -261,3 +261,7 @@ def _start_update_server(): thread.daemon = True thread.start() return server + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index aeb7ad4f2f83e..44d90f1437bc9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -324,10 +324,12 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None - def range(self, start, end, step=1, numSlices=None): + def range(self, start, end=None, step=1, numSlices=None): """ Create a new RDD of int containing elements from `start` to `end` - (exclusive), increased by `step` every element. + (exclusive), increased by `step` every element. Can be called the same + way as python's built-in range() function. If called with a single argument, + the argument is interpreted as `end`, and `start` is set to 0. 
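Editor note: the new docstring text above promises built-in range() behaviour, and the hunk a few lines below implements it by shifting the arguments when only one is supplied. A minimal sketch of that idiom in plain Python (hypothetical function name, no SparkContext involved):

    # Sketch of the argument-shifting idiom used by SparkContext.range:
    # a single positional argument becomes the exclusive end and
    # start defaults to 0, mirroring the built-in range().
    def range_like(start, end=None, step=1):
        if end is None:
            start, end = 0, start
        return list(range(start, end, step))

    assert range_like(5) == [0, 1, 2, 3, 4]
    assert range_like(2, 4) == [2, 3]
    assert range_like(1, 7, 2) == [1, 3, 5]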
:param start: the start value :param end: the end value (exclusive) @@ -335,9 +337,17 @@ def range(self, start, end, step=1, numSlices=None): :param numSlices: the number of partitions of the new RDD :return: An RDD of int + >>> sc.range(5).collect() + [0, 1, 2, 3, 4] + >>> sc.range(2, 4).collect() + [2, 3] >>> sc.range(1, 7, 2).collect() [1, 3, 5] """ + if end is None: + end = start + start = 0 + return self.parallelize(xrange(start, end, step), numSlices) def parallelize(self, c, numSlices=None): diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 23c37167b3711..d8ddb78c6d639 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -205,7 +205,7 @@ def getMetricName(self): def setParams(self, predictionCol="prediction", labelCol="label", metricName="rmse"): """ - setParams(self, predictionCol="prediction", labelCol="label", + setParams(self, predictionCol="prediction", labelCol="label", \ metricName="rmse") Sets params for regression evaluator. """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b0479d9b074db..ddb33f427ac64 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -324,65 +324,73 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ - A one-hot encoder that maps a column of label indices to a column of binary vectors, with - at most a single one-value. By default, the binary vector has an element for each category, so - with 5 categories, an input value of 2.0 would map to an output vector of - (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so - the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - linearly dependent because they sum up to one. - - TODO: This method requires the use of StringIndexer first. Decouple them. + A one-hot encoder that maps a column of category indices to a + column of binary vectors, with at most a single one-value per row + that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to + an output vector of `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via + :py:attr:`dropLast`) because it makes the vector entries sum up to + one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + Note that this is different from scikit-learn's OneHotEncoder, + which keeps all categories. + The output vectors are sparse. + + .. 
seealso:: + + :py:class:`StringIndexer` for converting categorical values into + category indices >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) - >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features") + >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features") >>> encoder.transform(td).head().features - SparseVector(2, {}) + SparseVector(2, {0: 1.0}) >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs - SparseVector(2, {}) - >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"} + SparseVector(2, {0: 1.0}) + >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) """ # a placeholder to make it appear in the generated doc - includeFirst = Param(Params._dummy(), "includeFirst", "include first category") + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") @keyword_only - def __init__(self, includeFirst=True, inputCol=None, outputCol=None): + def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ __init__(self, includeFirst=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) - self.includeFirst = Param(self, "includeFirst", "include first category") - self._setDefault(includeFirst=True) + self.dropLast = Param(self, "dropLast", "whether to drop the last category") + self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, includeFirst=True, inputCol=None, outputCol=None): + def setParams(self, dropLast=True, inputCol=None, outputCol=None): """ - setParams(self, includeFirst=True, inputCol=None, outputCol=None) + setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def setIncludeFirst(self, value): + def setDropLast(self, value): """ - Sets the value of :py:attr:`includeFirst`. + Sets the value of :py:attr:`dropLast`. """ - self._paramMap[self.includeFirst] = value + self._paramMap[self.dropLast] = value return self - def getIncludeFirst(self): + def getDropLast(self): """ - Gets the value of includeFirst or its default value. + Gets the value of dropLast or its default value. """ - return self.getOrDefault(self.includeFirst) + return self.getOrDefault(self.dropLast) @inherit_doc diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b3e0dd7abf681..b06099ac0aee6 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha indicated user preferences rather than explicit ratings given to items. + >>> df = sqlContext.createDataFrame( + ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ... 
["user", "item", "rating"]) >>> als = ALS(rank=10, maxIter=5) >>> model = als.fit(df) + >>> model.rank + 10 + >>> model.userFactors.orderBy("id").collect() + [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)] >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] @@ -260,6 +267,27 @@ class ALSModel(JavaModel): Model fitted by ALS. """ + @property + def rank(self): + """rank of the matrix factorization model""" + return self._call_java("rank") + + @property + def userFactors(self): + """ + a DataFrame that stores user factors in two columns: `id` and + `features` + """ + return self._call_java("userFactors") + + @property + def itemFactors(self): + """ + a DataFrame that stores item factors in two columns: `id` and + `features` + """ + return self._call_java("itemFactors") + if __name__ == "__main__": import doctest @@ -272,8 +300,6 @@ class ALSModel(JavaModel): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), - (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"]) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() if failure_count: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 497841b6c8ce6..0bf988fd72f14 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -91,20 +91,19 @@ class CrossValidator(Estimator): >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator >>> from pyspark.mllib.linalg import Vectors >>> dataset = sqlContext.createDataFrame( - ... [(Vectors.dense([0.0, 1.0]), 0.0), - ... (Vectors.dense([1.0, 2.0]), 1.0), - ... (Vectors.dense([0.55, 3.0]), 0.0), - ... (Vectors.dense([0.45, 4.0]), 1.0), - ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10, + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, ... ["features", "label"]) >>> lr = LogisticRegression() - >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - >>> # SPARK-7432: The following test is flaky. - >>> # cvModel = cv.fit(dataset) - >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset) - >>> # cvModel.transform(dataset).collect() == expected.collect() + >>> cvModel = cv.fit(dataset) + >>> evaluator.evaluate(cvModel.transform(dataset)) + 0.8333... """ # a placeholder to make it appear in the generated doc diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 07507b2ad0d05..acba3a717d21a 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -23,16 +23,10 @@ # MLlib currently needs NumPy 1.4+, so complain if lower import numpy -if numpy.version.version < '1.4': + +ver = [int(x) for x in numpy.version.version.split('.')[:2]] +if ver < [1, 4]: raise Exception("MLlib requires NumPy 1.4+") __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] - -import sys -from . 
import rand as random -modname = __name__ + '.random' -random.__name__ = modname -random.RandomRDDs.__module__ = modname -sys.modules[modname] = random -del modname, sys diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index ba6058978880a..855e85f57155e 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -27,7 +27,7 @@ from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer - +from pyspark.sql import DataFrame, SQLContext # Hack for support float('inf') in Py4j _old_smart_decode = py4j.protocol.smart_decode @@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) + if clsName == 'DataFrame': + return DataFrame(r, SQLContext(sc)) + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index aab5e5f4b77b5..c5cf3a4e7ff22 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ Evaluator for binary classification. + :param scoreAndLabels: an RDD of (score, label) pairs + >>> scoreAndLabels = sc.parallelize([ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) >>> metrics = BinaryClassificationMetrics(scoreAndLabels) @@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ def __init__(self, scoreAndLabels): - """ - :param scoreAndLabels: an RDD of (score, label) pairs - """ sc = scoreAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ @@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper): """ Evaluator for regression. + :param predictionAndObservations: an RDD of (prediction, + observation) pairs. + >>> predictionAndObservations = sc.parallelize([ ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) @@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper): """ def __init__(self, predictionAndObservations): - """ - :param predictionAndObservations: an RDD of (prediction, observation) pairs. - """ sc = predictionAndObservations.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ @@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. + :param predictionAndLabels an RDD of (prediction, label) pairs. + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) @@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels an RDD of (prediction, label) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ @@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper): """ Evaluator for ranking algorithms. + :param predictionAndLabels: an RDD of (predicted ranking, + ground truth set) pairs. + >>> predictionAndLabels = sc.parallelize([ ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]), ... 
([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]), @@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, @@ -347,6 +345,10 @@ class MultilabelMetrics(JavaModelWrapper): """ Evaluator for multilabel classification. + :param predictionAndLabels: an RDD of (predictions, labels) pairs, + both are non-null Arrays, each with + unique elements. + >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), ... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index aac305db6c19a..da90554f41437 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -68,6 +68,8 @@ class Normalizer(VectorTransformer): For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + :param p: Normalization in L^p^ space, p = 2 by default. + >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) >>> nor.transform(v) @@ -82,9 +84,6 @@ class Normalizer(VectorTransformer): DenseVector([0.0, 0.5, 1.0]) """ def __init__(self, p=2.0): - """ - :param p: Normalization in L^p^ space, p = 2 by default. - """ assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) @@ -94,7 +93,7 @@ def transform(self, vector): :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it - will return the input vector. + will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" @@ -164,6 +163,13 @@ class StandardScaler(object): variance using column summary statistics on the samples in the training set. + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] >>> dataset = sc.parallelize(vs) >>> standardizer = StandardScaler(True, True) @@ -174,14 +180,6 @@ class StandardScaler(object): DenseVector([0.7071, -0.7071, 0.7071]) """ def __init__(self, withMean=False, withStd=True): - """ - :param withMean: False by default. Centers the data with mean - before scaling. It will build a dense output, so this - does not work on sparse input and will raise an - exception. - :param withStd: True by default. Scales the data to unit - standard deviation. - """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") self.withMean = withMean @@ -193,7 +191,7 @@ def fit(self, dataset): for later scaling. :param data: The data used to compute the mean and variance - to build the transformation model. + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -223,6 +221,8 @@ class ChiSqSelector(object): Creates a ChiSquared feature selector. + :param numTopFeatures: number of features that selector will select. + >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), ... 
LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), @@ -236,9 +236,6 @@ class ChiSqSelector(object): DenseVector([5.0]) """ def __init__(self, numTopFeatures): - """ - :param numTopFeatures: number of features that selector will select. - """ self.numTopFeatures = int(numTopFeatures) def fit(self, data): @@ -246,9 +243,9 @@ def fit(self, data): Returns a ChiSquared feature selector. :param data: an `RDD[LabeledPoint]` containing the labeled dataset - with categorical features. Real-valued features will be - treated as categorical for each distinct value. - Apply feature discretizer before using this function. + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + Apply feature discretizer before using this function. """ jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data) return ChiSqSelectorModel(jmodel) @@ -263,15 +260,14 @@ class HashingTF(object): Note: the terms must be hashable (can not be dict/set/list...). + :param numFeatures: number of features (default: 2^20) + >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): - """ - :param numFeatures: number of features (default: 2^20) - """ self.numFeatures = numFeatures def indexOf(self, term): @@ -311,7 +307,7 @@ def transform(self, x): Call transform directly on the RDD instead. :param x: an RDD of term frequency vectors or a term frequency - vector + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ if isinstance(x, RDD): @@ -342,6 +338,9 @@ class IDF(object): `minDocFreq`). For terms that are not in at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. + :param minDocFreq: minimum of documents in which a term + should appear for filtering + >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), ... Vectors.dense([0.0, 1.0, 2.0, 3.0]), @@ -362,10 +361,6 @@ class IDF(object): SparseVector(4, {1: 0.0, 3: 0.5754}) """ def __init__(self, minDocFreq=0): - """ - :param minDocFreq: minimum of documents in which a term - should appear for filtering - """ self.minDocFreq = minDocFreq def fit(self, dataset): diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py similarity index 100% rename from python/pyspark/mllib/rand.py rename to python/pyspark/mllib/random.py diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py new file mode 100644 index 0000000000000..7da921976d4d2 --- /dev/null +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys + +if sys.version > '3': + xrange = range + +import numpy as np + +from pyspark.mllib.common import callMLlibFunc +from pyspark.rdd import RDD + + +class KernelDensity(object): + """ + .. note:: Experimental + + Estimate probability density at required points given a RDD of samples + from the population. + + >>> kd = KernelDensity() + >>> sample = sc.parallelize([0.0, 1.0]) + >>> kd.setSample(sample) + >>> kd.estimate([0.0, 1.0]) + array([ 0.12938758, 0.12938758]) + """ + def __init__(self): + self._bandwidth = 1.0 + self._sample = None + + def setBandwidth(self, bandwidth): + """Set bandwidth of each sample. Defaults to 1.0""" + self._bandwidth = bandwidth + + def setSample(self, sample): + """Set sample points from the population. Should be a RDD""" + if not isinstance(sample, RDD): + raise TypeError("samples should be a RDD, received %s" % type(sample)) + self._sample = sample + + def estimate(self, points): + """Estimate the probability density at points""" + points = list(points) + densities = callMLlibFunc( + "estimateKernelDensity", self._sample, self._bandwidth, points) + return np.asarray(densities) diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py index e3e128513e0d7..c8a721d3fe41c 100644 --- a/python/pyspark/mllib/stat/__init__.py +++ b/python/pyspark/mllib/stat/__init__.py @@ -22,6 +22,7 @@ from pyspark.mllib.stat._statistics import * from pyspark.mllib.stat.distribution import MultivariateGaussian from pyspark.mllib.stat.test import ChiSqTestResult +from pyspark.mllib.stat.KernelDensity import KernelDensity __all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", - "MultivariateGaussian"] + "MultivariateGaussian", "KernelDensity"] diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 1d0b16cade8bb..81c420ce16541 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -362,7 +362,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def items(self): """ Return all merged items as iterator """ @@ -515,7 +515,7 @@ def load(f): gc.collect() batch //= 2 limit = self._next_limit() - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close @@ -630,7 +630,7 @@ def _spill(self): self.values = [] gc.collect() DiskBytesSpilled += self._file.tell() - pos - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 class ExternalListOfList(ExternalList): @@ -794,7 +794,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def _merged_items(self, index): size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 8fee92ae3aed5..ad9c891ba1c04 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -45,22 +45,19 @@ def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. 
+ """ + import re + indent_p = re.compile(r'\n( +)') + def deco(f): - f.__doc__ = f.__doc__.rstrip() + "\n\n.. versionadded:: %s" % version + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) return f return deco -# fix the module name conflict for Python 3+ -import sys -from . import _types as types -modname = __name__ + '.types' -types.__name__ = modname -# update the __module__ for all objects, make them picklable -for v in types.__dict__.values(): - if hasattr(v, "__module__") and v.__module__.endswith('._types'): - v.__module__ = modname -sys.modules[modname] = types -del modname, sys from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext @@ -70,7 +67,9 @@ def deco(f): from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from pyspark.sql.window import Window, WindowSpec + __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8dc5039f587f0..1ecec5b126505 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this expression is between the given columns. + + >>> df.select(df.name, df.age.between(2, 4)).show() + +-----+--------------------------+ + | name|((age >= 2) && (age <= 4))| + +-----+--------------------------+ + |Alice| true| + | Bob| false| + +-----+--------------------------+ """ return (self >= lowerBound) & (self <= upperBound) @@ -328,12 +336,20 @@ def when(self, condition, value): :param condition: a boolean :class:`Column` expression. :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+--------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| + +-----+--------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+--------------------------------------------------------+ """ - sc = SparkContext._active_spark_context if not isinstance(condition, Column): raise TypeError("condition should be a Column") v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) + jc = self._jc.when(condition._jc, v) return Column(jc) @since(1.4) @@ -345,9 +361,18 @@ def otherwise(self, value): See :func:`pyspark.sql.functions.when` for example usage. :param value: a literal value, or a :class:`Column` expression. 
+ + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+---------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0| + +-----+---------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+---------------------------------+ """ v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) + jc = self._jc.otherwise(v) return Column(jc) @since(1.4) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 22f6257dfe02d..599c9ac5794a2 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -124,11 +124,14 @@ def getConf(self, key, defaultValue): @property @since("1.3.1") def udf(self): - """Returns a :class:`UDFRegistration` for UDF registration.""" + """Returns a :class:`UDFRegistration` for UDF registration. + + :return: :class:`UDFRegistration` + """ return UDFRegistration(self) @since(1.4) - def range(self, start, end, step=1, numPartitions=None): + def range(self, start, end=None, step=1, numPartitions=None): """ Create a :class:`DataFrame` with single LongType column named `id`, containing elements in a range from `start` to `end` (exclusive) with @@ -138,14 +141,24 @@ def range(self, start, end, step=1, numPartitions=None): :param end: the end value (exclusive) :param step: the incremental step (default: 1) :param numPartitions: the number of partitions of the DataFrame - :return: A new DataFrame + :return: :class:`DataFrame` >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] + + If only one argument is specified, it will be used as the end value. + + >>> sqlContext.range(3).collect() + [Row(id=0), Row(id=1), Row(id=2)] """ if numPartitions is None: numPartitions = self._sc.defaultParallelism - jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + + if end is None: + jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) + else: + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + return DataFrame(jdf, self) @ignore_unicode_prefix @@ -195,8 +208,8 @@ def _inferSchema(self, rdd, samplingRatio=None): raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") + warnings.warn("Using RDD of dict to inferSchema is deprecated. " + "Use pyspark.sql.Row instead") if samplingRatio is None: schema = _infer_schema(first) @@ -219,7 +232,7 @@ def inferSchema(self, rdd, samplingRatio=None): """ .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead") + warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -262,6 +275,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :class:`list`, or :class:`pandas.DataFrame`. :param schema: a :class:`StructType` or list of column names. default None. 
:param samplingRatio: the sample ratio of rows used for inferring + :return: :class:`DataFrame` >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() @@ -359,18 +373,15 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") - @since(1.0) def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. + + >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ + warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") gateway = self._sc._gateway jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) for i in range(0, len(paths)): @@ -378,39 +389,15 @@ def parquetFile(self, *paths): jdf = self._ssql_ctx.parquetFile(jpaths) return DataFrame(jdf, self) - @since(1.0) def jsonFile(self, path, schema=None, samplingRatio=1.0): """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.jsonFile(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.jsonFile(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes + [('age', 'bigint'), ('name', 'string')] """ + warnings.warn("jsonFile is deprecated. Use read.json() instead.") if schema is None: df = self._ssql_ctx.jsonFile(path, samplingRatio) else: @@ -462,21 +449,16 @@ def func(iterator): df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return DataFrame(df, self) - @since(1.3) def load(self, path=None, source=None, schema=None, **options): """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Optionally, a schema can be provided as the schema of the returned DataFrame. + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. """ + warnings.warn("load is deprecated. 
Use read.load() instead.") return self.read.load(path, source, schema, **options) @since(1.3) - def createExternalTable(self, tableName, path=None, source=None, - schema=None, **options): + def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): """Creates an external table based on the dataset in a data source. It returns the DataFrame associated with the external table. @@ -487,6 +469,8 @@ def createExternalTable(self, tableName, path=None, source=None, Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. + + :return: :class:`DataFrame` """ if path is not None: options["path"] = path @@ -508,6 +492,8 @@ def createExternalTable(self, tableName, path=None, source=None, def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -519,6 +505,8 @@ def sql(self, sqlQuery): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -536,6 +524,9 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + :param dbName: string, name of the database to use. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() @@ -550,7 +541,8 @@ def tables(self, dbName=None): def tableNames(self, dbName=None): """Returns a list of names of tables in the database ``dbName``. - If ``dbName`` is not specified, the current database will be used. + :param dbName: string, name of the database to use. Default to the current database. + :return: list of table names, in string >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() @@ -585,8 +577,7 @@ def read(self): Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. 
- >>> sqlContext.read - + :return: :class:`DataFrameReader` """ return DataFrameReader(self) @@ -644,10 +635,14 @@ def register(self, name, f, returnType=StringType()): def _test(): + import os import doctest from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 936487519a645..9615e576497cd 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -22,6 +22,7 @@ if sys.version >= '3': basestring = unicode = str long = int + from functools import reduce else: from itertools import imap as map @@ -44,7 +45,7 @@ class DataFrame(object): A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: - people = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. @@ -56,8 +57,8 @@ class DataFrame(object): A more concrete example:: # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") + department = sqlContext.read.parquet("...") people.filter(people.age > 30).join(department, people.deptId == department.id)) \ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) @@ -120,21 +121,12 @@ def toJSON(self, use_unicode=True): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - @since(1.3) def saveAsParquetFile(self, path): """Saves the contents as a Parquet file, preserving the schema. - Files that are written out using this method can be read back in as - a :class:`DataFrame` using :func:`SQLContext.parquetFile`. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. """ + warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") self._jdf.saveAsParquetFile(path) @since(1.3) @@ -151,69 +143,45 @@ def registerTempTable(self, name): """ self._jdf.registerTempTable(name) - @since(1.3) def registerAsTable(self, name): - """DEPRECATED: use :func:`registerTempTable` instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + """ + .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. + """ + warnings.warn("Use registerTempTable instead of registerAsTable.") self.registerTempTable(name) - @since(1.3) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this :class:`DataFrame` into the specified table. - Optionally overwriting any existing data. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. """ + warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") self.write.insertInto(tableName, overwrite) - @since(1.3) def saveAsTable(self, tableName, source=None, mode="error", **options): """Saves the contents of this :class:`DataFrame` to a data source as a table. 
- The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the saveAsTable operation when - table already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. """ + warnings.warn("saveAsTable is deprecated. Use write.saveAsTable() instead.") self.write.saveAsTable(tableName, source, mode, **options) @since(1.3) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. """ + warnings.warn("save is deprecated. Use write.save() instead.") return self.write.save(path, source, mode, **options) @property @since(1.4) def write(self): """ - Interface for saving the content of the :class:`DataFrame` out - into external storage. - - :return :class:`DataFrameWriter` + Interface for saving the content of the :class:`DataFrame` out into external storage. - .. note:: Experimental - - >>> df.write - + :return: :class:`DataFrameWriter` """ return DataFrameWriter(self) @@ -536,36 +504,52 @@ def alias(self, alias): @ignore_unicode_prefix @since(1.3) - def join(self, other, joinExprs=None, joinType=None): + def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join - :param joinExprs: a string for join column name, or a join expression (Column). - If joinExprs is a string indicating the name of the join column, - the column must exist on both sides, and this performs an inner equi-join. - :param joinType: str, default 'inner'. + :param on: a string for join column name, a list of column names, + a join expression (Column) or a list of Columns. + If `on` is a string or a list of strings indicating the name of the join column(s), + the column(s) must exist on both sides, and this performs an inner equi-join. + :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. 
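
Since ``on`` may now be a column name, a list of names, a Column expression, or a list of Columns, the accepted call shapes look roughly like this (a sketch using the same ``df``/``df2``/``df3``/``df4`` DataFrames as the doctests below):

    df.join(df2)                                  # no condition: plain cartesian join
    df.join(df2, 'name')                          # inner equi-join on one column name
    df.join(df4, ['name', 'age'])                 # inner equi-join on several column names
    df.join(df2, df.name == df2.name, 'outer')    # Column expression plus an explicit join type
    df.join(df3, [df.name == df3.name,
                  df.age == df3.age], 'outer')    # a list of Columns is AND-ed together
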
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + >>> cond = [df.name == df3.name, df.age == df3.age] + >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() + [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] + + >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() + [Row(name=u'Bob', age=5)] """ - if joinExprs is None: + if on is not None and not isinstance(on, list): + on = [on] + + if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - elif isinstance(joinExprs, basestring): - jdf = self._jdf.join(other._jdf, joinExprs) + + if isinstance(on[0], basestring): + jdf = self._jdf.join(other._jdf, self._jseq(on)) else: - assert isinstance(joinExprs, Column), "joinExprs should be Column" - if joinType is None: - jdf = self._jdf.join(other._jdf, joinExprs._jc) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + if how is None: + jdf = self._jdf.join(other._jdf, on._jc, "inner") else: - assert isinstance(joinType, basestring), "joinType should be basestring" - jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -636,6 +620,9 @@ def describe(self, *cols): This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical columns. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + >>> df.describe().show() +-------+---+ |summary|age| @@ -646,16 +633,30 @@ def describe(self, *cols): | min| 2| | max| 5| +-------+---+ + >>> df.describe(['age', 'name']).show() + +-------+---+-----+ + |summary|age| name| + +-------+---+-----+ + | count| 2| 2| + | mean|3.5| null| + | stddev|1.5| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+---+-----+ """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @since(1.3) def head(self, n=None): - """ - Returns the first ``n`` rows as a list of :class:`Row`, - or the first :class:`Row` if ``n`` is ``None.`` + """Returns the first ``n`` rows. + + :param n: int, default 1. Number of rows to return. + :return: If n is greater than 1, return a list of :class:`Row`. + If n is 1, return a single Row. >>> df.head() Row(age=2, name=u'Alice') @@ -745,7 +746,7 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] @@ -925,8 +926,7 @@ def dropDuplicates(self, subset=None): @since("1.3.1") def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. - - This is an alias for ``na.drop()``. + :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other. 
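
As the updated doctests below show, the null-handling helpers are now exercised through the ``df.na`` accessor; a compact sketch of the equivalent calls, assuming a DataFrame ``df4`` with missing values as in the doctests:

    df4.na.drop()                                         # alias of df4.dropna()
    df4.na.fill({'age': 50, 'name': 'unknown'})           # alias of df4.fillna(...)
    df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name')  # alias of df4.replace(...)
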
:param how: 'any' or 'all'. If 'any', drop a row if it contains any nulls. @@ -936,13 +936,6 @@ def dropna(self, how='any', thresh=None, subset=None): This overwrites the `how` parameter. :param subset: optional list of column names to consider. - >>> df4.dropna().show() - +---+------+-----+ - |age|height| name| - +---+------+-----+ - | 10| 80|Alice| - +---+------+-----+ - >>> df4.na.drop().show() +---+------+-----+ |age|height| name| @@ -968,6 +961,7 @@ def dropna(self, how='any', thresh=None, subset=None): @since("1.3.1") def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. + :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. :param value: int, long, float, string, or dict. Value to replace null values with. @@ -979,7 +973,7 @@ def fillna(self, value, subset=None): For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. - >>> df4.fillna(50).show() + >>> df4.na.fill(50).show() +---+------+-----+ |age|height| name| +---+------+-----+ @@ -989,16 +983,6 @@ def fillna(self, value, subset=None): | 50| 50| null| +---+------+-----+ - >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() - +---+------+-------+ - |age|height| name| - +---+------+-------+ - | 10| 80| Alice| - | 5| null| Bob| - | 50| null| Tom| - | 50| null|unknown| - +---+------+-------+ - >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() +---+------+-------+ |age|height| name| @@ -1030,6 +1014,8 @@ def fillna(self, value, subset=None): @since(1.4) def replace(self, to_replace, value, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. + :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are + aliases of each other. :param to_replace: int, long, float, string, or list. Value to be replaced. @@ -1045,7 +1031,7 @@ def replace(self, to_replace, value, subset=None): For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. - >>> df4.replace(10, 20).show() + >>> df4.na.replace(10, 20).show() +----+------+-----+ | age|height| name| +----+------+-----+ @@ -1055,7 +1041,7 @@ def replace(self, to_replace, value, subset=None): |null| null| null| +----+------+-----+ - >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| +----+------+----+ @@ -1106,9 +1092,9 @@ def replace(self, to_replace, value, subset=None): @since(1.4) def corr(self, col1, col2, method=None): """ - Calculates the correlation of two columns of a DataFrame as a double value. Currently only - supports the Pearson Correlation Coefficient. - :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases. + Calculates the correlation of two columns of a DataFrame as a double value. + Currently only supports the Pearson Correlation Coefficient. + :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other. :param col1: The name of the first column :param col2: The name of the second column @@ -1170,6 +1156,9 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. + .. 
note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. :param support: The frequency with which to consider an item 'frequent'. Default is 1%. @@ -1214,15 +1203,30 @@ def withColumnRenamed(self, existing, new): @since(1.4) @ignore_unicode_prefix - def drop(self, colName): + def drop(self, col): """Returns a new :class:`DataFrame` that drops the specified column. - :param colName: string, name of the column to drop. + :param col: a string name of the column to drop, or a + :class:`Column` to drop. >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.drop(df.age).collect() + [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() + [Row(age=5, height=85, name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() + [Row(age=5, name=u'Bob', height=85)] """ - jdf = self._jdf.drop(colName) + if isinstance(col, basestring): + jdf = self._jdf.drop(col) + elif isinstance(col, Column): + jdf = self._jdf.drop(col._jc) + else: + raise TypeError("col should be a string or a Column") return DataFrame(jdf, self.sql_ctx) @since(1.3) @@ -1239,7 +1243,10 @@ def toPandas(self): import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) + ########################################################################################## # Pandas compatibility + ########################################################################################## + groupby = groupBy drop_duplicates = dropDuplicates @@ -1259,6 +1266,8 @@ def _to_scala_map(sc, jm): class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. + + .. versionadded:: 1.4 """ def __init__(self, df): @@ -1274,9 +1283,16 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ + def replace(self, to_replace, value, subset=None): + return self.df.replace(to_replace, value, subset) + + replace.__doc__ = DataFrame.replace.__doc__ + class DataFrameStatFunctions(object): """Functionality for statistic functions with :class:`DataFrame`. + + .. versionadded:: 1.4 """ def __init__(self, df): @@ -1316,6 +1332,8 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2), + Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b6fd413bec7db..f036644acc961 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -43,6 +43,44 @@ def _df(self, jdf): from pyspark.sql.dataframe import DataFrame return DataFrame(jdf, self._sqlContext) + @since(1.4) + def format(self, source): + """Specifies the input data source format. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. 
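
The new ``format``/``schema``/``options`` methods each return the reader itself, so they compose as a builder ahead of the terminal ``load`` (or ``json``/``parquet``/``table``). A sketch assuming a SQLContext named ``sqlContext``; the explicit schema is only an illustration matching the ``people.json`` test file:

    from pyspark.sql.types import StructType, StructField, StringType, LongType

    people_schema = StructType([StructField('name', StringType()),
                                StructField('age', LongType())])

    df = (sqlContext.read
          .format('json')            # choose the data source
          .schema(people_schema)     # an explicit schema skips the inference step
          .load('python/test_support/sql/people.json'))
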
+ + >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + + """ + self._jreader = self._jreader.format(source) + return self + + @since(1.4) + def schema(self, schema): + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + :param schema: a StructType object + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(1.4) + def options(self, **options): + """Adds input options for the underlying data source. + """ + for k in options: + self._jreader = self._jreader.option(k, options[k]) + return self + @since(1.4) def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. @@ -51,21 +89,20 @@ def load(self, path=None, format=None, schema=None, **options): :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options + + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ - jreader = self._jreader if format is not None: - jreader = jreader.format(format) + self.format(format) if schema is not None: - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) - jreader = jreader.schema(jschema) - for k in options: - jreader = jreader.option(k, options[k]) + self.schema(schema) + self.options(**options) if path is not None: - return self._df(jreader.load(path)) + return self._df(self._jreader.load(path)) else: - return self._df(jreader.load()) + return self._df(self._jreader.load()) @since(1.4) def json(self, path, schema=None): @@ -79,47 +116,25 @@ def json(self, path, schema=None): :param path: string, path to the JSON dataset. :param schema: an optional :class:`StructType` for the input schema. - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.read.json(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... 
StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.read.json(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> df = sqlContext.read.json('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + """ - if schema is None: - jdf = self._jreader.json(path) - else: - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) - jdf = self._jreader.schema(jschema).json(path) - return self._df(jdf) + if schema is not None: + self.schema(schema) + return self._df(self._jreader.json(path)) @since(1.4) def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. - >>> sqlContext.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlContext.read.table("table1") - >>> sorted(df.collect()) == sorted(df2.collect()) - True + :param tableName: string, name of the table. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.registerTempTable('tmpTable') + >>> sqlContext.read.table('tmpTable').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.table(tableName)) @@ -127,13 +142,9 @@ def table(self, tableName): def parquet(self, *path): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.read.parquet(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path))) @@ -195,40 +206,88 @@ def __init__(self, df): self._jwrite = df._jdf.write() @since(1.4) - def save(self, path=None, format=None, mode="error", **options): - """ - Saves the contents of the :class:`DataFrame` to a data source. + def mode(self, saveMode): + """Specifies the behavior when data or table already exists. - The data source is specified by the ``format`` and a set of ``options``. - If ``format`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + Options include: * `append`: Append contents of this :class:`DataFrame` to existing data. * `overwrite`: Overwrite existing data. * `error`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite = self._jwrite.mode(saveMode) + return self + + @since(1.4) + def format(self, source): + """Specifies the underlying output data source. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(1.4) + def options(self, **options): + """Adds output options for the underlying data source. 
+ """ + for k in options: + self._jwrite = self._jwrite.option(k, options[k]) + return self + + @since(1.4) + def partitionBy(self, *cols): + """Partitions the output by the given columns on the file system. + + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + :param cols: name of columns + + >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + return self + + @since(1.4) + def save(self, path=None, format=None, mode="error", **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``format`` and a set of ``options``. + If ``format`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + :param path: the path in a Hadoop supported file system :param format: the format used to save - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. :param options: all other string options + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - jwrite = self._jwrite.mode(mode) + self.mode(mode).options(**options) if format is not None: - jwrite = jwrite.format(format) - for k in options: - jwrite = jwrite.option(k, options[k]) + self.format(format) if path is None: - jwrite.save() + self._jwrite.save() else: - jwrite.save(path) + self._jwrite.save(path) + @since(1.4) def insertInto(self, tableName, overwrite=False): - """ - Inserts the content of the :class:`DataFrame` to the specified table. + """Inserts the content of the :class:`DataFrame` to the specified table. + It requires that the schema of the class:`DataFrame` is the same as the schema of the table. @@ -238,8 +297,7 @@ def insertInto(self, tableName, overwrite=False): @since(1.4) def saveAsTable(self, name, format=None, mode="error", **options): - """ - Saves the content of the :class:`DataFrame` as the specified table. + """Saves the content of the :class:`DataFrame` as the specified table. In the case the table already exists, behavior of this function depends on the save mode, specified by the `mode` function (default to throwing an exception). @@ -256,72 +314,61 @@ def saveAsTable(self, name, format=None, mode="error", **options): :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) :param options: all other string options """ - jwrite = self._jwrite.mode(mode) + self.mode(mode).options(**options) if format is not None: - jwrite = jwrite.format(format) - for k in options: - jwrite = jwrite.option(k, options[k]) - return jwrite.saveAsTable(name) + self.format(format) + self._jwrite.saveAsTable(name) @since(1.4) def json(self, path, mode="error"): - """ - Saves the content of the :class:`DataFrame` in JSON format at the - specified path. + """Saves the content of the :class:`DataFrame` in JSON format at the specified path. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. 
There are four modes: + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. - :param path: the path in any Hadoop supported file system - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ - return self._jwrite.mode(mode).json(path) + self._jwrite.mode(mode).json(path) @since(1.4) def parquet(self, path, mode="error"): - """ - Saves the content of the :class:`DataFrame` in Parquet format at the - specified path. + """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. - :param path: the path in any Hadoop supported file system - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - return self._jwrite.mode(mode).parquet(path) + self._jwrite.mode(mode).parquet(path) @since(1.4) def jdbc(self, url, table, mode="error", properties={}): - """ - Saves the content of the :class:`DataFrame` to a external database table - via JDBC. - - In the case the table already exists in the external database, - behavior of this function depends on the save mode, specified by the `mode` - function (default to throwing an exception). There are four modes: + """Saves the content of the :class:`DataFrame` to a external database table via JDBC. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Don't create too many partitions in parallel on a large cluster;\ + otherwise Spark might crash your external database systems. - :param url: a JDBC URL of the form `jdbc:subprotocol:subname` + :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` :param table: Name of the table in the external database. - :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param mode: specifies the behavior of the save operation when data already exists. 
+ + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. :param properties: JDBC database connection arguments, a list of - arbitrary string tag/value. Normally at least a - "user" and "password" property should be included. + arbitrary string tag/value. Normally at least a + "user" and "password" property should be included. """ jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: @@ -331,24 +378,23 @@ def jdbc(self, url, table, mode="error", properties={}): def _test(): import doctest + import os + import tempfile from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.readwriter + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') + + globs['tempfile'] = tempfile + globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings + globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5c53c3a8ed4f1..a6fce50c76c2b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -100,6 +100,15 @@ def test_data_type_eq(self): lt2 = pickle.loads(pickle.dumps(LongType())) self.assertEquals(lt, lt2) + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + class SQLTests(ReusedPySparkTestCase): @@ -122,6 +131,8 @@ def test_range(self): self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.sqlCtx.range(-2).count(), 0) + self.assertEqual(self.sqlCtx.range(3).count(), 3) def test_explode(self): from pyspark.sql.functions import explode @@ -744,8 +755,10 @@ def setUpClass(cls): try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: + cls.tearDownClass() raise unittest.SkipTest("Hive is not available") except TypeError: + cls.tearDownClass() raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) _scala_HiveContext =\ diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py similarity index 99% rename from python/pyspark/sql/_types.py rename to python/pyspark/sql/types.py index 9e7e9f04bc35d..b6ec6137c9180 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/types.py @@ -97,8 +97,6 @@ class AtomicType(DataType): """An internal type used to represent everything that is not 
null, UDTs, arrays, structs, and maps.""" - __metaclass__ = DataTypeSingleton - class NumericType(AtomicType): """Numeric data types. @@ -109,6 +107,8 @@ class IntegralType(NumericType): """Integral data types. """ + __metaclass__ = DataTypeSingleton + class FractionalType(NumericType): """Fractional data types. @@ -119,26 +119,36 @@ class StringType(AtomicType): """String data type. """ + __metaclass__ = DataTypeSingleton + class BinaryType(AtomicType): """Binary (byte array) data type. """ + __metaclass__ = DataTypeSingleton + class BooleanType(AtomicType): """Boolean data type. """ + __metaclass__ = DataTypeSingleton + class DateType(AtomicType): """Date (datetime.date) data type. """ + __metaclass__ = DataTypeSingleton + class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. """ + __metaclass__ = DataTypeSingleton + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -172,11 +182,15 @@ class DoubleType(FractionalType): """Double data type, representing double precision floats. """ + __metaclass__ = DataTypeSingleton + class FloatType(FractionalType): """Float data type, representing single precision floats. """ + __metaclass__ = DataTypeSingleton + class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte. diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 0a0e006bdf83a..c74745c726a0c 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -32,7 +32,6 @@ def _to_java_cols(cols): class Window(object): - """ Utility functions for defining window in DataFrames. diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33ea8c9293d74..57049beea4dba 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -41,8 +41,8 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 4 # seconds - duration = .2 + timeout = 10 # seconds + duration = .5 @classmethod def setUpClass(cls): @@ -379,13 +379,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 5 + timeout = 15 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(.6, .2).count() + return dstream.window(1.5, .5).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -394,7 +394,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(.6, .2) + return dstream.countByWindow(1.5, .5) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -403,7 +403,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(1, .2) + return dstream.countByWindow(2.5, .5) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -412,7 +412,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(1, .2) + return dstream.countByValueAndWindow(2.5, .5) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -421,7 +421,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) + 
return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] @@ -615,7 +615,6 @@ def test_kafka_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, @@ -631,7 +630,6 @@ def test_kafka_direct_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) self._validateStreamResult(sendData, stream) @@ -646,7 +644,6 @@ def test_kafka_direct_stream_from_offset(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) self._validateStreamResult(sendData, stream) @@ -661,7 +658,6 @@ def test_kafka_rdd(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self._validateRddResult(sendData, rdd) @@ -677,7 +673,6 @@ def test_kafka_rdd_with_leaders(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) diff --git a/python/run-tests b/python/run-tests index ffde2fb24b369..4468fdb3f267e 100755 --- a/python/run-tests +++ b/python/run-tests @@ -57,54 +57,57 @@ function run_test() { function run_core_tests() { echo "Run core tests ..." - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_test "pyspark.rdd" + run_test "pyspark.context" + run_test "pyspark.conf" + run_test "pyspark.broadcast" + run_test "pyspark.accumulators" + run_test "pyspark.serializers" + run_test "pyspark.profiler" + run_test "pyspark.shuffle" + run_test "pyspark.tests" } function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql/_types.py" - run_test "pyspark/sql/context.py" - run_test "pyspark/sql/column.py" - run_test "pyspark/sql/dataframe.py" - run_test "pyspark/sql/group.py" - run_test "pyspark/sql/functions.py" - run_test "pyspark/sql/tests.py" + run_test "pyspark.sql.types" + run_test "pyspark.sql.context" + run_test "pyspark.sql.column" + run_test "pyspark.sql.dataframe" + run_test "pyspark.sql.group" + run_test "pyspark.sql.functions" + run_test "pyspark.sql.readwriter" + run_test "pyspark.sql.window" + run_test "pyspark.sql.tests" } function run_mllib_tests() { echo "Run mllib tests ..." 
- run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/evaluation.py" - run_test "pyspark/mllib/feature.py" - run_test "pyspark/mllib/fpm.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/rand.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat/_statistics.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/tests.py" + run_test "pyspark.mllib.classification" + run_test "pyspark.mllib.clustering" + run_test "pyspark.mllib.evaluation" + run_test "pyspark.mllib.feature" + run_test "pyspark.mllib.fpm" + run_test "pyspark.mllib.linalg" + run_test "pyspark.mllib.random" + run_test "pyspark.mllib.recommendation" + run_test "pyspark.mllib.regression" + run_test "pyspark.mllib.stat._statistics" + run_test "pyspark.mllib.stat.KernelDensity" + run_test "pyspark.mllib.tree" + run_test "pyspark.mllib.util" + run_test "pyspark.mllib.tests" } function run_ml_tests() { echo "Run ml tests ..." - run_test "pyspark/ml/feature.py" - run_test "pyspark/ml/classification.py" - run_test "pyspark/ml/recommendation.py" - run_test "pyspark/ml/regression.py" - run_test "pyspark/ml/tuning.py" - run_test "pyspark/ml/tests.py" - run_test "pyspark/ml/evaluation.py" + run_test "pyspark.ml.feature" + run_test "pyspark.ml.classification" + run_test "pyspark.ml.recommendation" + run_test "pyspark.ml.regression" + run_test "pyspark.ml.tuning" + run_test "pyspark.ml.tests" + run_test "pyspark.ml.evaluation" } function run_streaming_tests() { @@ -124,8 +127,8 @@ function run_streaming_tests() { done export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" + run_test "pyspark.streaming.util" + run_test "pyspark.streaming.tests" } echo "Running PySpark tests. Output is in python/$LOG_FILE." 
diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata new file mode 100644 index 0000000000000..7ef2320651dee Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_common_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata new file mode 100644 index 0000000000000..78a1ca7d38279 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc new file mode 100644 index 0000000000000..e93f42ed6f350 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet new file mode 100644 index 0000000000000..461c382937ecd Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc new file mode 100644 index 0000000000000..b63c4d6d1e1dc Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc new file mode 100644 index 0000000000000..5bc0ebd713563 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet new file mode 100644 index 0000000000000..62a63915beac2 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet new file mode 100644 index 0000000000000..67665a7b55da6 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc new file mode 100644 index 0000000000000..ae94a15d08c81 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc differ diff --git 
a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet new file mode 100644 index 0000000000000..6cb8538aa8904 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc new file mode 100644 index 0000000000000..58d9bb5fc5883 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet new file mode 100644 index 0000000000000..9b00805481e7b Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet differ diff --git a/python/test_support/sql/people.json b/python/test_support/sql/people.json new file mode 100644 index 0000000000000..50a859cbd7ee8 --- /dev/null +++ b/python/test_support/sql/people.json @@ -0,0 +1,3 @@ +{"name":"Michael"} +{"name":"Andy", "age":30} +{"name":"Justin", "age":19} diff --git a/repl/pom.xml b/repl/pom.xml index 03053b4c3b287..85f7bc8ac1024 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -48,6 +48,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-bagel_${scala.binary.version} diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 934daaeaafca1..50fd43a418bca 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,13 +22,12 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 14f5e9ed4f25e..9ecc7c229e38a 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.tools.nsc.interpreter.SparkILoop -import org.scalatest.FunSuite import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, 
input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c709cde740748..a58eda12b1120 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -25,7 +25,6 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar @@ -35,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.util.Utils class ExecutorClassLoaderSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfterAll with MockitoSugar with Logging {
diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7168d5b2a8e26..d6f927b6fa803 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -14,25 +14,41 @@ ~ See the License for the specific language governing permissions and ~ limitations under the License. -->
[Body of this scalastyle-config.xml hunk: the XML markup was lost in extraction. Recoverable fragments show a restructured "Scalastyle standard configuration" in which several checks are enabled ("true"), whitespace checks are configured for the token lists ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW and ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW, a regex check bans ^FunSuite[A-Za-z]*$ with the message "Tests must extend org.apache.spark.SparkFunSuite instead.", a regex check bans ^println$, and numeric rule parameters 800, 30, 10, 50 and -1,0,1,2,3 are set.]
diff --git a/sql/catalyst/pom.xml index 5c322d032d474..f4b1cc3a4ffe7 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -36,10 +36,6 @@ - - org.scala-lang - scala-compiler - org.scala-lang scala-reflect @@ -50,6 +46,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-unsafe_${scala.binary.version} @@ -60,6 +63,11 @@ scalacheck_${scala.binary.version} test + + org.codehaus.janino + janino + 2.7.8 + target/scala-${scala.binary.version}/classes @@ -101,13 +109,6 @@ !scala-2.11 - - - org.scalamacros - quasiquotes_${scala.binary.version} - ${scala.macros.version} - -
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb546b3086b33..ec97fe603c44f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,23 +17,25 @@ package org.apache.spark.sql.catalyst.expressions; -import scala.collection.Map; +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + import scala.collection.Seq; import scala.collection.mutable.ArraySeq; -import javax.annotation.Nullable; -import java.math.BigDecimal; -import java.sql.Date; -import java.util.*; - import org.apache.spark.sql.Row; +import org.apache.spark.sql.BaseMutableRow; import org.apache.spark.sql.types.DataType; -import static org.apache.spark.sql.types.DataTypes.*; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.UTF8String; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; +import static org.apache.spark.sql.types.DataTypes.*; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -49,7 +51,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format.
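The scalastyle rule summarized above, together with the ReplSuite and ExecutorClassLoaderSuite hunks, moves test suites from org.scalatest.FunSuite onto org.apache.spark.SparkFunSuite (shipped in the spark-core test-jar that several modules now depend on). As a minimal sketch of what an affected suite looks like after the migration, the suite name and test body below are made up; only the import and the base class come from this patch:

package org.apache.spark.repl

import org.apache.spark.SparkFunSuite

// Hypothetical example suite: the ScalaTest-style test(...) body is unchanged,
// but the base class is now SparkFunSuite instead of org.scalatest.FunSuite.
class ExampleMigratedSuite extends SparkFunSuite {
  test("interpreter output contains the evaluated result") {
    val output = s"res0: Int = ${1 + 2}"  // stand-in for runInterpreter("local", "1 + 2")
    assert(output.contains("Int = 3"))
  }
}
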
*/ -public final class UnsafeRow implements MutableRow { +public final class UnsafeRow extends BaseMutableRow { private Object baseObject; private long baseOffset; @@ -227,21 +229,11 @@ public int size() { return numFields; } - @Override - public int length() { - return size(); - } - @Override public StructType schema() { return schema; } - @Override - public Object apply(int i) { - return get(i); - } - @Override public Object get(int i) { assertIndexIsValid(i); @@ -339,60 +331,7 @@ public String getString(int i) { return getUTF8String(i).toString(); } - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } @Override public Row copy() { @@ -412,24 +351,4 @@ public Seq toSeq() { } return values; } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java new file mode 100644 index 0000000000000..acec2bf4520f2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql; + +import org.apache.spark.sql.catalyst.expressions.MutableRow; + +public abstract class BaseMutableRow extends BaseRow implements MutableRow { + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setLong(int ordinal, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDouble(int ordinal, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setShort(int ordinal, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setByte(int ordinal, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setFloat(int ordinal, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java new file mode 100644 index 0000000000000..d138b43a3482b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql; + +import java.math.BigDecimal; +import java.sql.Date; +import java.util.List; + +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +public abstract class BaseRow implements Row { + + @Override + final public int length() { + return size(); + } + + @Override + public boolean anyNull() { + final int n = size(); + for (int i=0; i < n; i++) { + if (isNullAt(i)) { + return true; + } + } + return false; + } + + @Override + public StructType schema() { throw new UnsupportedOperationException(); } + + @Override + final public Object apply(int i) { + return get(i); + } + + @Override + public int getInt(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Seq getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public List getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.Map getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public java.util.Map getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + final int n = size(); + Object[] arr = new Object[n]; + for (int i = 0; i < n; i++) { + arr[i] = get(i); + } + return new GenericRow(arr); + } + + @Override + public Seq toSeq() { + final int n = size(); + final ArraySeq values = new ArraySeq(n); + for (int i = 0; i < n; i++) { + values.update(i, get(i)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 2eb3e167baad5..ef7b3ad9432cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -103,7 +103,7 @@ class SqlLexical extends StdLexical { ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { - case i ~ None => NumericLit(i.mkString) + case i ~ None => NumericLit(i.mkString) case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) } | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 75a493b248f6e..2e7b4c236d8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -18,7 +18,10 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.Date import java.util.{Map => JavaMap} +import javax.annotation.Nullable import scala.collection.mutable.HashMap @@ -34,197 +37,338 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + private def isPrimitive(dataType: DataType): Boolean = { + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case _ => false + } + } + + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { + val converter = dataType match { + case udt: UserDefinedType[_] => UDTConverter(udt) + case arrayType: ArrayType => ArrayConverter(arrayType.elementType) + case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) + case structType: StructType => StructConverter(structType) + case StringType => StringConverter + case DateType => DateConverter + case dt: DecimalType => BigDecimalConverter + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType => IntConverter + case LongType => LongConverter + case FloatType => FloatConverter + case DoubleType => DoubleConverter + case _ => IdentityConverter + } + converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] + } + /** - * Converts Scala objects to catalyst rows / types. This method is slow, and for batch - * conversion you should be using converter produced by createToCatalystConverter. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. + * Converts a Scala type to its Catalyst equivalent (and vice versa). + * + * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst. + * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala. + * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type. 
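As a self-contained illustration of the converter pattern this hunk introduces (one small converter object per data type, with null and Option handling done once in a shared entry point), here is a simplified sketch that uses stand-in types and a toy conversion rather than Spark's classes:

// Simplified sketch of the CatalystTypeConverter idea; names and the toy
// String "conversion" are illustrative, not Spark's implementation.
object ConverterPatternSketch {
  abstract class TypeConverter[ScalaIn, Catalyst] extends Serializable {
    // Shared entry point: nulls and Options are unwrapped exactly once here,
    // so concrete converters only ever see non-null values.
    final def toCatalyst(maybeValue: Any): Any = maybeValue match {
      case null    => null
      case None    => null
      case Some(v) => toCatalystImpl(v.asInstanceOf[ScalaIn])
      case v       => toCatalystImpl(v.asInstanceOf[ScalaIn])
    }
    protected def toCatalystImpl(value: ScalaIn): Catalyst
  }

  // Stand-in for the real StringConverter (which produces UTF8String).
  object UpperCaseStringConverter extends TypeConverter[String, String] {
    protected def toCatalystImpl(value: String): String = value.toUpperCase
  }

  def main(args: Array[String]): Unit = {
    println(UpperCaseStringConverter.toCatalyst("spark"))        // SPARK
    println(UpperCaseStringConverter.toCatalyst(Some("spark")))  // SPARK
    println(UpperCaseStringConverter.toCatalyst(None))           // null
  }
}
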
*/ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => - udt.serialize(obj) - - case (o: Option[_], _) => - o.map(convertToCatalyst(_, dataType)).orNull - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToCatalyst(_, arrayType.elementType)) - - case (jit: JavaIterable[_], arrayType: ArrayType) => { - val iter = jit.iterator - var listOfItems: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - listOfItems :+= convertToCatalyst(item, arrayType.elementType) + private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType] + extends Serializable { + + /** + * Converts a Scala type to its Catalyst equivalent while automatically handling nulls + * and Options. + */ + final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { + if (maybeScalaValue == null) { + null.asInstanceOf[CatalystType] + } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) { + val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]] + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + } else { + toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType]) } - listOfItems } - case (s: Array[_], arrayType: ArrayType) => - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + */ + final def toScala(row: Row, column: Int): ScalaOutputType = { + if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) + } - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } + /** + * Convert a Catalyst value to its Scala equivalent. + */ + def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType + + /** + * Converts a Scala value to its Catalyst equivalent. + * @param scalaValue the Scala value, guaranteed not to be null. + * @return the Catalyst value. + */ + protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + * This method will only be called on non-null columns. + */ + protected def toScalaImpl(row: Row, column: Int): ScalaOutputType + } + + private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toScalaImpl(row: Row, column: Int): Any = row(column) + } + + private case class UDTConverter( + udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column)) + } - case (jmap: JavaMap[_, _], mapType: MapType) => - val iter = jmap.entrySet.iterator - var listOfEntries: List[(Any, Any)] = List() - while (iter.hasNext) { - val entry = iter.next() - listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), - convertToCatalyst(entry.getValue, mapType.valueType)) + /** Converter for arrays, sequences, and Java iterables. 
*/ + private case class ArrayConverter( + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + + private[this] val elementConverter = getConverterForType(elementType) + + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + scalaValue match { + case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) + case s: Seq[_] => s.map(elementConverter.toCatalyst) + case i: JavaIterable[_] => + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter.toCatalyst(item) + } + convertedIterable } - listOfEntries.toMap - - case (p: Product, structType: StructType) => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) - idx += 1 + } + + override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) } - new GenericRowWithSchema(ar, structType) + } + + override def toScalaImpl(row: Row, column: Int): Seq[Any] = + toScala(row(column).asInstanceOf[Seq[Any]]) + } - case (d: String, _) => - UTF8String(d) + private case class MapConverter( + keyType: DataType, + valueType: DataType) + extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] { - case (d: BigDecimal, _) => - Decimal(d) + private[this] val keyConverter = getConverterForType(keyType) + private[this] val valueConverter = getConverterForType(valueType) - case (d: java.math.BigDecimal, _) => - Decimal(d) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v) + } - case (d: java.sql.Date, _) => - DateUtils.fromJavaDate(d) + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + val key = keyConverter.toCatalyst(entry.getKey) + convertedMap(key) = valueConverter.toCatalyst(entry.getValue) + } + convertedMap + } - case (r: Row, structType: StructType) => - val converters = structType.fields.map { - f => (item: Any) => convertToCatalyst(item, f.dataType) + override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.map { case (k, v) => + keyConverter.toScala(k) -> valueConverter.toScala(v) + } } - convertRowWithConverters(r, structType, converters) + } - case (other, _) => - other + override def toScalaImpl(row: Row, column: Int): Map[Any, Any] = + toScala(row(column).asInstanceOf[Map[Any, Any]]) } - /** - * Creates a converter function that will convert Scala objects to the specified catalyst type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. 
- */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { - def extractOption(item: Any): Any = item match { - case opt: Option[_] => opt.orNull - case other => other - } + private case class StructConverter( + structType: StructType) extends CatalystTypeConverter[Any, Row, Row] { - dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item) => extractOption(item) match { - case null => null - case other => udt.serialize(other) - } + private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } - case arrayType: ArrayType => - val elementConverter = createToCatalystConverter(arrayType.elementType) - (item: Any) => { - extractOption(item) match { - case a: Array[_] => a.toSeq.map(elementConverter) - case s: Seq[_] => s.map(elementConverter) - case i: JavaIterable[_] => { - val iter = i.iterator - var convertedIterable: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - convertedIterable :+= elementConverter(item) - } - convertedIterable - } - case null => null - } + override def toCatalystImpl(scalaValue: Any): Row = scalaValue match { + case row: Row => + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toCatalyst(row(idx)) + idx += 1 } - - case mapType: MapType => - val keyConverter = createToCatalystConverter(mapType.keyType) - val valueConverter = createToCatalystConverter(mapType.valueType) - (item: Any) => { - extractOption(item) match { - case m: Map[_, _] => - m.map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - - case jmap: JavaMap[_, _] => - val iter = jmap.entrySet.iterator - val convertedMap: HashMap[Any, Any] = HashMap() - while (iter.hasNext) { - val entry = iter.next() - convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) - } - convertedMap - - case null => null - } + new GenericRowWithSchema(ar, structType) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(iter.next()) + idx += 1 } + new GenericRowWithSchema(ar, structType) + } - case structType: StructType => - val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) - (item: Any) => { - extractOption(item) match { - case r: Row => - convertRowWithConverters(r, structType, converters) - - case p: Product => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = converters(idx)(iter.next()) - idx += 1 - } - new GenericRowWithSchema(ar, structType) - - case null => - null - } + override def toScala(row: Row): Row = { + if (row == null) { + null + } else { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toScala(row, idx) + idx += 1 } - - case dateType: DateType => (item: Any) => extractOption(item) match { - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case other => other + new GenericRowWithSchema(ar, structType) } + } - case dataType: StringType => (item: Any) => extractOption(item) match { - case s: String => UTF8String(s) - case other => other - } + override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row]) + } + + private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue 
match { + case str: String => UTF8String(str) + case utf8: UTF8String => utf8 + } + override def toScala(catalystValue: Any): String = catalystValue match { + case null => null + case str: String => str + case utf8: UTF8String => utf8.toString() + } + override def toScalaImpl(row: Row, column: Int): String = row(column).toString + } + + private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { + override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue) + override def toScala(catalystValue: Any): Date = + if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) + } - case _ => - (item: Any) => extractOption(item) match { - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) - case other => other + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: Decimal => d + } + override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal + override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match { + case d: JavaBigDecimal => d + case d: Decimal => d.toJavaBigDecimal + } + } + + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { + final override def toScala(catalystValue: Any): Any = catalystValue + final override def toCatalystImpl(scalaValue: T): Any = scalaValue + } + + private object BooleanConverter extends PrimitiveConverter[Boolean] { + override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) + } + + private object ByteConverter extends PrimitiveConverter[Byte] { + override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) + } + + private object ShortConverter extends PrimitiveConverter[Short] { + override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) + } + + private object IntConverter extends PrimitiveConverter[Int] { + override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) + } + + private object LongConverter extends PrimitiveConverter[Long] { + override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) + } + + private object FloatConverter extends PrimitiveConverter[Float] { + override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) + } + + private object DoubleConverter extends PrimitiveConverter[Double] { + override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) + } + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toCatalyst(scalaValue) + } + + /** + * Creates a converter function that will convert Scala objects to the specified Catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. 
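The DateConverter above stores a java.sql.Date in Catalyst as an Int, delegating the actual encoding to DateUtils (a day count since the Unix epoch). A self-contained sketch of that representation, deliberately ignoring the time-zone handling the real DateUtils performs, might look like:

import java.sql.Date
import java.util.concurrent.TimeUnit

// Toy stand-in for DateUtils.fromJavaDate / DateUtils.toJavaDate: dates become
// a day count since 1970-01-01. Time zones are ignored for brevity.
object DateRoundTripSketch {
  private val MillisPerDay: Long = TimeUnit.DAYS.toMillis(1)

  def fromJavaDate(d: Date): Int = (d.getTime / MillisPerDay).toInt
  def toJavaDate(days: Int): Date = new Date(days * MillisPerDay)

  def main(args: Array[String]): Unit = {
    val d = Date.valueOf("2015-06-01")
    val catalystForm = fromJavaDate(d)  // e.g. 16587
    println(s"$d -> $catalystForm -> ${toJavaDate(catalystForm)}")
  }
}
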
+ */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + // Although the `else` branch here is capable of handling inbound conversion of primitives, + // we add some special-case handling for those types here. The motivation for this relates to + // Java method invocation costs: if we have rows that consist entirely of primitive columns, + // then returning the same conversion function for all of the columns means that the call site + // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in + // a measurable performance impact. Note that this optimization will be unnecessary if we + // use code generation to construct Scala Row -> Catalyst Row converters. + def convert(maybeScalaValue: Any): Any = { + if (maybeScalaValue.isInstanceOf[Option[Any]]) { + maybeScalaValue.asInstanceOf[Option[Any]].orNull + } else { + maybeScalaValue } + } + convert + } else { + getConverterForType(dataType).toCatalyst } } /** - * Converts Scala objects to catalyst rows / types. + * Converts Scala objects to Catalyst rows / types. * * Note: This should be called before do evaluation on Row * (It does not support UDT) * This is used to create an RDD or test results with correct types for Catalyst. */ def convertToCatalyst(a: Any): Any = a match { - case s: String => UTF8String(s) - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) + case s: String => StringConverter.toCatalyst(s) + case d: Date => DateConverter.toCatalyst(d) + case d: BigDecimal => BigDecimalConverter.toCatalyst(d) + case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray @@ -233,38 +377,13 @@ object CatalystTypeConverters { case other => other } - /** + /** * Converts Catalyst types used internally in rows to standard Scala types * This method is slow, and for batch conversion you should be using converter * produced by createToScalaConverter. */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => - udt.deserialize(d) - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToScala(_, arrayType.elementType)) - - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - - case (r: Row, s: StructType) => - convertRowToScala(r, s) - - case (d: Decimal, _: DecimalType) => - d.toJavaBigDecimal - - case (i: Int, DateType) => - DateUtils.toJavaDate(i) - - case (s: UTF8String, StringType) => - s.toString() - - case (other, _) => - other + def convertToScala(catalystValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toScala(catalystValue) } /** @@ -272,82 +391,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. 
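Both createToCatalystConverter and createToScalaConverter are meant to be called once per schema, with the returned function then applied to every row, so the per-type dispatch (and, for primitive-only schemas, the monomorphic fast path described in the comment above) stays out of the hot loop. A simplified, non-Spark sketch of that usage pattern:

// Toy version of the "build converters once, apply them per row" pattern.
// The column "types" and conversions below are stand-ins, not Spark's DataTypes.
object ConverterUsageSketch {
  type Converter = Any => Any

  def createConverter(typeName: String): Converter = typeName match {
    case "string" => { case s: String => s.toUpperCase }  // stand-in for String -> UTF8String
    case "int"    => identity                             // primitives pass through unchanged
    case _        => identity
  }

  def main(args: Array[String]): Unit = {
    val schema = Seq("string", "int")
    val rows = Seq(Seq("alice", 1), Seq("bob", 2))

    // Dispatch on the data type once, outside the loop ...
    val converters = schema.map(createConverter)
    // ... then apply the cached converters to every row.
    val converted = rows.map(row => row.zip(converters).map { case (v, convert) => convert(v) })
    println(converted)  // List(List(ALICE, 1), List(BOB, 2))
  }
}
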
*/ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item: Any) => if (item == null) null else udt.deserialize(item) - - case arrayType: ArrayType => - val elementConverter = createToScalaConverter(arrayType.elementType) - (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) - - case mapType: MapType => - val keyConverter = createToScalaConverter(mapType.keyType) - val valueConverter = createToScalaConverter(mapType.valueType) - (item: Any) => if (item == null) { - null - } else { - item.asInstanceOf[Map[_, _]].map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - } - - case s: StructType => - val converters = s.fields.map(f => createToScalaConverter(f.dataType)) - (item: Any) => { - if (item == null) { - null - } else { - convertRowWithConverters(item.asInstanceOf[Row], s, converters) - } - } - - case _: DecimalType => - (item: Any) => item match { - case d: Decimal => d.toJavaBigDecimal - case other => other - } - - case DateType => - (item: Any) => item match { - case i: Int => DateUtils.toJavaDate(i) - case other => other - } - - case StringType => - (item: Any) => item match { - case s: UTF8String => s.toString() - case other => other - } - - case other => - (item: Any) => item - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - val ar = new Array[Any](r.size) - var idx = 0 - while (idx < r.size) { - ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) - idx += 1 - } - new GenericRowWithSchema(ar, schema) - } - - /** - * Converts a row by applying the provided set of converter functions. It is used for both - * toScala and toCatalyst conversions. - */ - private[sql] def convertRowWithConverters( - row: Row, - schema: StructType, - converters: Array[Any => Any]): Row = { - val ar = new Array[Any](row.size) - var idx = 0 - while (idx < row.size) { - ar(idx) = converters(idx)(row(idx)) - idx += 1 - } - new GenericRowWithSchema(ar, schema) + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + getConverterForType(dataType).toScala } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 625c8d3a62125..9a3f9694e4c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -38,12 +38,21 @@ private [sql] object JavaTypeInference { private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + /** + * Infers the corresponding SQL data type of a JavaClean class. + * @param beanClass Java type + * @return (SQL data type, nullable) + */ + def inferDataType(beanClass: Class[_]): (DataType, Boolean) = { + inferDataType(TypeToken.of(beanClass)) + } + /** * Infers the corresponding SQL data type of a Java type. * @param typeToken Java type * @return (SQL data type, nullable) */ - private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. 
typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index fc36b9f1f20d2..f74c17d583359 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -48,26 +49,21 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AS = Keyword("AS") protected val ASC = Keyword("ASC") - protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") - protected val COALESCE = Keyword("COALESCE") - protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") - protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") @@ -80,13 +76,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") - protected val LAST = Keyword("LAST") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") - protected val LOWER = Keyword("LOWER") - protected val MAX = Keyword("MAX") - protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") @@ -100,15 +92,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val RLIKE = Keyword("RLIKE") protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") - protected val SQRT = Keyword("SQRT") - protected val SUBSTR = Keyword("SUBSTR") - protected val SUBSTRING = Keyword("SUBSTRING") - protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") - protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") protected val WITH = Keyword("WITH") @@ -140,7 +127,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { (HAVING ~> expression).? ~ sortType.? ~ (LIMIT ~> expression).? 
^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g @@ -212,7 +199,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val ordering: Parser[Seq[SortOrder]] = ( rep1sep(expression ~ direction.? , ",") ^^ { - case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) + case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) @@ -242,7 +229,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { case e ~ not ~ el ~ eu => val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - not.fold(betweenExpr)(f=> Not(betweenExpr)) + not.fold(betweenExpr)(f => Not(betweenExpr)) } | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } @@ -277,25 +264,36 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val function: Parser[Expression] = - ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } - | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } - | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } - | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } - | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^ - { case exps => CountDistinct(exps) } - | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ - { case exp => ApproxCountDistinct(exp) } - | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } - | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } - | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } - | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } - | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } - | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } - | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } - | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } - | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case c ~ t ~ f => If(c, t, f) } + ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => + if (lexical.normalizeKeyword(udfName) == "count") { + Count(Literal(1)) + } else { + throw new AnalysisException(s"invalid expression $udfName(*)") + } + } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => + lexical.normalizeKeyword(udfName) match { + case "sum" => SumDistinct(exprs.head) + case "count" => CountDistinct(exprs) + } + } + | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp) + } else { + throw new AnalysisException(s"invalid function approximate $udfName") + } + } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp, s.toDouble) + } else { + throw new 
AnalysisException(s"invalid function approximate($floatLit) $udfName") + } + } | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { case casePart ~ altPart ~ elsePart => @@ -304,16 +302,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } ++ elsePart casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p ~ l => Substring(s, p, l) } - | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } - | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } - | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } - | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } - ) + ) protected lazy val cast: Parser[Expression] = CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { @@ -365,7 +354,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } + | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } | primary ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c239e83271615..02b10c444d1a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -235,9 +235,8 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete - * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's - * children. + * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from + * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -455,14 +454,16 @@ class Analyzer( } /** - * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. + * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { case u @ UnresolvedFunction(name, children) if u.childrenResolved => - registry.lookupFunction(name, children) + withPosition(u) { + registry.lookupFunction(name, children) + } } } } @@ -494,7 +495,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => { - val evaluatedCondition = Alias(havingCondition, "havingCondition")() + val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs Project(aggregate.output, @@ -515,16 +516,15 @@ class Analyzer( * - concrete attribute references for their output. * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). 
* - * Names for the output [[Attributes]] are extracted from [[Alias]] or [[MultiAlias]] expressions + * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an * [[AnalysisException]] is throw. */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Generate if !p.child.resolved || !p.generator.resolved => p - case g: Generate if g.resolved == false => - g.copy( - generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + case g: Generate if !g.resolved => + g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) case p @ Project(projectList, child) => // Holds the resolved generator, if one exists in the project list. @@ -634,10 +634,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = + private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = projectList.exists(hasWindowFunction) - def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: NamedExpression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -645,14 +645,24 @@ class Analyzer( } /** - * From a Seq of [[NamedExpression]]s, extract window expressions and - * other regular expressions. + * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and + * other regular expressions that do not contain any window expression. For example, for + * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract + * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in + * the window expression as attribute references. So, the first returned value will be + * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be + * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. + * + * @return (seq of expressions containing at lease one window expressions, + * seq of non-window expressions) */ - def extract( + private def extract( expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = { - // First, we simple partition the input expressions to two part, one having - // WindowExpressions and another one without WindowExpressions. - val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction) + // First, we partition the input expressions to two part. For the first part, + // every expression in it contain at least one WindowExpression. + // Expressions in the second part do not have any WindowExpression. + val (expressionsWithWindowFunctions, regularExpressions) = + expressions.partition(hasWindowFunction) // Then, we need to extract those regular expressions used in the WindowExpression. // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5), @@ -661,8 +671,8 @@ class Analyzer( val extractedExprBuffer = new ArrayBuffer[NamedExpression]() def extractExpr(expr: Expression): Expression = expr match { case ne: NamedExpression => - // If a named expression is not in regularExpressions, add extract it and replace it - // with an AttributeReference. + // If a named expression is not in regularExpressions, add it to + // extractedExprBuffer and replace it with an AttributeReference. 
val missingExpr = AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer) if (missingExpr.nonEmpty) { @@ -679,8 +689,9 @@ class Analyzer( withName.toAttribute } - // Now, we extract expressions from windowExpressions by using extractExpr. - val newWindowExpressions = windowExpressions.map { + // Now, we extract regular expressions from expressionsWithWindowFunctions + // by using extractExpr. + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). @@ -706,37 +717,80 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - (newWindowExpressions, regularExpressions ++ extractedExprBuffer) - } + (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer) + } // end of extract /** * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. */ - def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = { - // First, we group window expressions based on their Window Spec. - val groupedWindowExpression = windowExpressions.groupBy { expr => - val windowSpec = expr.collectFirst { + private def addWindow( + expressionsWithWindowFunctions: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions + // and put those extracted WindowExpressions to extractedWindowExprBuffer. + // This step is needed because it is possible that an expression contains multiple + // WindowExpressions with different Window Specs. + // After extracting WindowExpressions, we need to construct a project list to generate + // expressionsWithWindowFunctions based on extractedWindowExprBuffer. + // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract + // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to + // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)". + // Then, the projectList will be [_we0/_we1]. + val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]() + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { + // We need to use transformDown because we want to trigger + // "case alias @ Alias(window: WindowExpression, _)" first. + _.transformDown { + case alias @ Alias(window: WindowExpression, _) => + // If a WindowExpression has an assigned alias, just use it. + extractedWindowExprBuffer += alias + alias.toAttribute + case window: WindowExpression => + // If there is no alias assigned to the WindowExpressions. We create an + // internal column. + val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")() + extractedWindowExprBuffer += withName + withName.toAttribute + }.asInstanceOf[NamedExpression] + } + + // Second, we group extractedWindowExprBuffer based on their Window Spec. + val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => + val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec + }.distinct + + // We do a final check and see if we only have a single Window Spec defined in an + // expressions. + if (distinctWindowSpec.length == 0 ) { + failAnalysis(s"$expr does not have any WindowExpression.") + } else if (distinctWindowSpec.length > 1) { + // newExpressionsWithWindowFunctions only have expressions with a single + // WindowExpression. If we reach here, we have a bug. 
+ failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + + s"Please file a bug report with this error message, stack trace, and the query.") + } else { + distinctWindowSpec.head } - windowSpec.getOrElse( - failAnalysis(s"$windowExpressions does not have any WindowExpression.")) }.toSeq - // For every Window Spec, we add a Window operator and set currentChild as the child of it. + // Third, for every Window Spec, we add a Window operator and set currentChild as the + // child of it. var currentChild = child var i = 0 - while (i < groupedWindowExpression.size) { - val (windowSpec, windowExpressions) = groupedWindowExpression(i) + while (i < groupedWindowExpressions.size) { + val (windowSpec, windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) - // Move to next WindowExpression. + // Move to next Window Spec. i += 1 } - // We return the top operator. - currentChild - } + // Finally, we create a Project to output currentChild's output + // newExpressionsWithWindowFunctions. + Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + } // end of addWindow // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. @@ -793,9 +847,8 @@ class Analyzer( } /** - * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are - * only required to provide scoping information for attributes and can be removed once analysis is - * complete. + * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide + * scoping information for attributes and can be removed once analysis is complete. 
*/ object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 208021c421326..1541491608b24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.EmptyConf @@ -81,18 +85,18 @@ trait Catalog { } class SimpleCatalog(val conf: CatalystConf) extends Catalog { - val tables = new mutable.HashMap[String, LogicalPlan]() + val tables = new ConcurrentHashMap[String, LogicalPlan] override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables += ((getDbTableName(tableIdent), plan)) + tables.put(getDbTableName(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables -= getDbTableName(tableIdent) + tables.remove(getDbTableName(tableIdent)) } override def unregisterAllTables(): Unit = { @@ -101,10 +105,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - tables.get(getDbTableName(tableIdent)) match { - case Some(_) => true - case None => false - } + tables.containsKey(getDbTableName(tableIdent)) } override def lookupRelation( @@ -112,7 +113,10 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val tableFullName = getDbTableName(tableIdent) - val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName")) + val table = tables.get(tableFullName) + if (table == null) { + sys.error(s"Table Not Found: $tableFullName") + } val tableWithQualifiers = Subquery(tableIdent.last, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are @@ -121,9 +125,11 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - tables.map { - case (name, _) => (name, true) - }.toSeq + val result = ArrayBuffer.empty[(String, Boolean)] + for (name <- tables.keySet()) { + result += ((name, true)) + } + result } override def refreshTable(databaseName: String, tableName: String): Unit = { @@ -140,7 +146,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... 
- val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + val overrides = new mutable.HashMap[(Option[String], String), LogicalPlan]() abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 193dc6b6546b5..c0695ae369421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,15 +62,17 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } + case c: Cast if !c.resolved => failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case b: BinaryExpression if !b.resolved => - failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => failAnalysis( s"Could not resolve window function '$name'. " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0849faa9bfa7b..406f6fad8413b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,24 +17,27 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.Expression -import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.StringKeyHashMap + /** A catalog for looking up user defined functions, used by an [[Analyzer]]. 
*/ trait FunctionRegistry { - type FunctionBuilder = Seq[Expression] => Expression def registerFunction(name: String, builder: FunctionBuilder): Unit + @throws[AnalysisException]("If function does not exist") def lookupFunction(name: String, children: Seq[Expression]): Expression - - def conf: CatalystConf } trait OverrideFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) @@ -45,16 +48,19 @@ trait OverrideFunctionRegistry extends FunctionRegistry { } } -class SimpleFunctionRegistry(val conf: CatalystConf) extends FunctionRegistry { +class SimpleFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders(name)(children) + val func = functionBuilders.get(name).getOrElse { + throw new AnalysisException(s"undefined function $name") + } + func(children) } } @@ -70,30 +76,89 @@ object EmptyFunctionRegistry extends FunctionRegistry { override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - - override def conf: CatalystConf = throw new UnsupportedOperationException } -/** - * Build a map with String type of key, and it also supports either key case - * sensitive or insensitive. - * TODO move this into util folder? 
- */ -object StringKeyHashMap { - def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { - case false => new StringKeyHashMap[T](_.toLowerCase) - case true => new StringKeyHashMap[T](identity) - } -} -class StringKeyHashMap[T](normalizer: (String) => String) { - private val base = new collection.mutable.HashMap[String, T]() +object FunctionRegistry { - def apply(key: String): T = base(normalizer(key)) + type FunctionBuilder = Seq[Expression] => Expression - def get(key: String): Option[T] = base.get(normalizer(key)) - def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) - def remove(key: String): Option[T] = base.remove(normalizer(key)) - def iterator: Iterator[(String, T)] = base.toIterator + val expressions: Map[String, FunctionBuilder] = Map( + // Non aggregate functions + expression[Abs]("abs"), + expression[CreateArray]("array"), + expression[Coalesce]("coalesce"), + expression[Explode]("explode"), + expression[Lower]("lower"), + expression[Substring]("substr"), + expression[Substring]("substring"), + expression[Rand]("rand"), + expression[Randn]("randn"), + expression[CreateStruct]("struct"), + expression[Sqrt]("sqrt"), + expression[Upper]("upper"), + + // Math functions + expression[Acos]("acos"), + expression[Asin]("asin"), + expression[Atan]("atan"), + expression[Atan2]("atan2"), + expression[Cbrt]("cbrt"), + expression[Ceil]("ceil"), + expression[Cos]("cos"), + expression[Exp]("exp"), + expression[Expm1]("expm1"), + expression[Floor]("floor"), + expression[Hypot]("hypot"), + expression[Log]("log"), + expression[Log10]("log10"), + expression[Log1p]("log1p"), + expression[Pow]("pow"), + expression[Rint]("rint"), + expression[Signum]("signum"), + expression[Sin]("sin"), + expression[Sinh]("sinh"), + expression[Tan]("tan"), + expression[Tanh]("tanh"), + expression[ToDegrees]("todegrees"), + expression[ToRadians]("toradians"), + + // aggregate functions + expression[Average]("avg"), + expression[Count]("count"), + expression[First]("first"), + expression[Last]("last"), + expression[Max]("max"), + expression[Min]("min"), + expression[Sum]("sum") + ) + + /** See usage above. */ + private def expression[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { + // Use the companion class to find apply methods. + val objectClass = Class.forName(tag.runtimeClass.getName + "$") + val companionObj = objectClass.getDeclaredField("MODULE$").get(null) + + // See if we can find an apply that accepts Seq[Expression] + val varargApply = Try(objectClass.getDeclaredMethod("apply", classOf[Seq[_]])).toOption + + val builder = (expressions: Seq[Expression]) => { + if (varargApply.isDefined) { + // If there is an apply method that accepts Seq[Expression], use that one. + varargApply.get.invoke(companionObj, expressions).asInstanceOf[Expression] + } else { + // Otherwise, find an apply method that matches the number of arguments, and use that. 
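// Illustrative example (not part of this patch): for expression[Substring]("substr"),
// a call like substr(s, 1, 2) reaches this branch with three child expressions, so the
// builder looks up Substring.apply(Expression, Expression, Expression) on the companion
// object and invokes it, yielding Substring(s, Literal(1), Literal(2)).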
+ val params = Seq.fill(expressions.size)(classOf[Expression]) + val f = Try(objectClass.getDeclaredMethod("apply", params : _*)) match { + case Success(e) => + e + case Failure(e) => + throw new AnalysisException(s"Invalid number of arguments for function $name") + } + f.invoke(companionObj, expressions : _*).asInstanceOf[Expression] + } + } + (name, builder) + } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index b45b17d856fac..737905c3582ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -41,7 +41,7 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -57,6 +57,17 @@ object HiveTypeCoercion { case _ => None } + + /** + * Find the tightest common type of a set of types by continuously applying + * `findTightestCommonTypeOfTwo` on these types. + */ + private def findTightestCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonTypeOfTwo(d, c) + }) + } } /** @@ -76,8 +87,7 @@ trait HiveTypeCoercion { WidenTypes :: PromoteStrings :: DecimalPrecision :: - BooleanComparisons :: - BooleanCasts :: + BooleanEquality :: StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: @@ -120,7 +130,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. 
*/ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN") + private val StringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -128,20 +138,20 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => - b.makeCopy(Array(b.right, Literal(Double.NaN))) - case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) + case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + b.makeCopy(Array(Literal(Double.NaN), right)) + case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => - b.makeCopy(Array(b.right, Literal(Float.NaN))) - case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) + case b @ BinaryExpression(StringNaN, right @ FloatType()) => + b.makeCopy(Array(Literal(Float.NaN), right)) + case b @ BinaryExpression(left @ FloatType(), StringNaN) => + b.makeCopy(Array(left, Literal(Float.NaN))) + + /* Use float NaN by default to avoid unnecessary type widening */ + case b @ BinaryExpression(left @ StringNaN, StringNaN) => + b.makeCopy(Array(left, Literal(Float.NaN))) } } } @@ -174,21 +184,25 @@ trait HiveTypeCoercion { case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. - case (l, r) if l.dataType == StringType && r.dataType != StringType => - (l, Alias(Cast(r, StringType), r.name)()) - case (l, r) if l.dataType != StringType && r.dataType == StringType => - (Alias(Cast(l, StringType), l.name)(), r) - - case (l, r) if l.dataType != r.dataType => - logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonType(l.dataType, r.dataType).map { widestType => + case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => + (lhs, Alias(Cast(rhs, StringType), rhs.name)()) + case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => + (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + + case (lhs, rhs) if lhs.dataType != rhs.dataType => + logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") + findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => val newLeft = - if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() + if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() val newRight = - if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)() + if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() (newLeft, newRight) - }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged. + }.getOrElse { + // If there is no applicable conversion, leave expression unchanged. 
+ (lhs, rhs) + } + case other => other } @@ -217,12 +231,10 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => - val newLeft = - if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) - val newRight = - if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) b.makeCopy(Array(newLeft, newRight)) }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. } @@ -237,57 +249,42 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a: BinaryArithmetic if a.left.dataType == StringType => - a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) - case a: BinaryArithmetic if a.right.dataType == StringType => - a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + case a @ BinaryArithmetic(left @ StringType(), r) => + a.makeCopy(Array(Cast(left, DoubleType), r)) + case a @ BinaryArithmetic(left, right @ StringType()) => + a.makeCopy(Array(left, Cast(right, DoubleType))) // we should cast all timestamp/date/string compare into string compare - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType == DateType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) - case p: BinaryComparison if p.left.dataType == DateType && - p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) - case p: BinaryComparison if p.left.dataType == TimestampType && - p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) - case p: BinaryComparison if p.left.dataType == TimestampType && - p.right.dataType == DateType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - case p: BinaryComparison if p.left.dataType == DateType && - p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType != StringType => - p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) - case p: BinaryComparison if p.left.dataType != StringType && - p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - - case i @ In(a, b) if a.dataType == DateType && - b.forall(_.dataType == StringType) => + case p @ BinaryComparison(left @ StringType(), right @ DateType()) => + p.makeCopy(Array(left, Cast(right, StringType))) + case p @ BinaryComparison(left @ DateType(), right @ StringType()) => + p.makeCopy(Array(Cast(left, StringType), right)) + case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, TimestampType), right)) + case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(left, Cast(right, TimestampType))) + case p @ BinaryComparison(left @ TimestampType(), 
right @ DateType()) => + p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + + case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType => + p.makeCopy(Array(Cast(left, DoubleType), right)) + case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType => + p.makeCopy(Array(left, Cast(right, DoubleType))) + + case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a, b) if a.dataType == TimestampType && - b.forall(_.dataType == StringType) => + case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) - case i @ In(a, b) if a.dataType == DateType && - b.forall(_.dataType == TimestampType) => + case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case i @ In(a, b) if a.dataType == TimestampType && - b.forall(_.dataType == DateType) => + case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case Sum(e) if e.dataType == StringType => - Sum(Cast(e, DoubleType)) - case Average(e) if e.dataType == StringType => - Average(Cast(e, DoubleType)) - case Sqrt(e) if e.dataType == StringType => - Sqrt(Cast(e, DoubleType)) + case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType)) } } @@ -297,8 +294,8 @@ trait HiveTypeCoercion { object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. 
- case e if !e.childrenResolved => e - + case e if !e.childrenResolved => e + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) } @@ -350,17 +347,17 @@ trait HiveTypeCoercion { import scala.math.{max, min} // Conversion rules for integer types into fixed-precision decimals - val intTypeToFixed: Map[DataType, DecimalType] = Map( + private val intTypeToFixed: Map[DataType, DecimalType] = Map( ByteType -> DecimalType(3, 0), ShortType -> DecimalType(5, 0), IntegerType -> DecimalType(10, 0), LongType -> DecimalType(20, 0) ) - def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType // Conversion rules for float and double into fixed-precision decimals - val floatTypeToFixed: Map[DataType, DecimalType] = Map( + private val floatTypeToFixed: Map[DataType, DecimalType] = Map( FloatType -> DecimalType(7, 7), DoubleType -> DecimalType(15, 15) ) @@ -369,22 +366,22 @@ trait HiveTypeCoercion { // fix decimal precision for union case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { - case (l, r) if l.dataType != r.dataType => - (l.dataType, r.dataType) match { + case (lhs, rhs) if lhs.dataType != rhs.dataType => + (lhs.dataType, rhs.dataType) match { case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) - case _ => (l, r) + (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) + case _ => (lhs, rhs) } case other => other } @@ -442,34 +439,25 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + // When we compare 2 
decimal types with different precisions, cast them to the smallest + // common precision. + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = DecimalType(max(p1, p2), max(s1, s2)) + b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + b.makeCopy(Array(left, Cast(right, intTypeToFixed(t)))) case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + b.makeCopy(Array(left, Cast(right, DoubleType))) case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + b.makeCopy(Array(Cast(left, DoubleType), right)) case _ => b } @@ -483,56 +471,66 @@ trait HiveTypeCoercion { } /** - * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. + * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ - object BooleanComparisons extends Rule[LogicalPlan] { - val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_)) - val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_)) + object BooleanEquality extends Rule[LogicalPlan] { + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1)) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0)) + + private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { + CaseKeyWhen(numericExpr, Seq( + Literal(trueValues.head), booleanExpr, + Literal(falseValues.head), Not(booleanExpr), + Literal(false))) + } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + private def transform(booleanExpr: Expression, numericExpr: Expression) = { + If(Or(IsNull(booleanExpr), IsNull(numericExpr)), + Literal.create(null, BooleanType), + buildCaseKeyWhen(booleanExpr, numericExpr)) + } - // Hive treats (true = 1) as true and (false = 0) as true. - case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l - case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r - case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) - case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) - - // No need to change other EqualTo operators as that actually makes sense for boolean types. - case e: EqualTo => e - // No need to change the EqualNullSafe operators, too - case e: EqualNullSafe => e - // Otherwise turn them to Byte types so that there exists and ordering. 
- case p: BinaryComparison if p.left.dataType == BooleanType && - p.right.dataType == BooleanType => - p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { + CaseWhen(Seq( + And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), + Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), + buildCaseKeyWhen(booleanExpr, numericExpr) + )) } - } - /** - * Casts to/from [[BooleanType]] are transformed into comparisons since - * the JVM does not consider Booleans to be numeric types. - */ - object BooleanCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // Skip if the type is boolean type already. Note that this extra cast should be removed - // by optimizer.SimplifyCasts. - case Cast(e, BooleanType) if e.dataType == BooleanType => e - // DateType should be null if be cast to boolean. - case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) - // If the data type is not boolean and is being cast boolean, turn it into a comparison - // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. - case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) - // Stringify boolean if casting to StringType. - // TODO Ensure true/false string letter casing is consistent with Hive in all cases. - case Cast(e, StringType) if e.dataType == BooleanType => - If(e, Literal("true"), Literal("false")) - // Turn true into 1, and false into 0 if casting boolean into other types. - case Cast(e, dataType) if e.dataType == BooleanType => - Cast(If(e, Literal(1), Literal(0)), dataType) + + // Hive treats (true = 1) as true and (false = 0) as true, + // all other cases are considered as false. 
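// Illustrative examples (not part of this patch): for a boolean column b, the literal
// cases below rewrite `b = 1` to `b` and `b = 0` to `NOT b`; a non-literal numeric operand
// such as `b = intCol` is instead handled by the null-aware transform/transformNullSafe
// helpers defined above.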
+ + // We may simplify the expression if one side is literal numeric values + case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => bool + case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(bool) + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => bool + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => Not(bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) + + case EqualTo(left @ BooleanType(), right @ NumericType()) => + transform(left , right) + case EqualTo(left @ NumericType(), right @ BooleanType()) => + transform(right, left) + case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => + transformNullSafe(left, right) + case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => + transformNullSafe(right, left) } } @@ -561,8 +559,7 @@ trait HiveTypeCoercion { case a @ CreateArray(children) if !a.resolved => val commonType = a.childTypes.reduce( - (a,b) => - findTightestCommonType(a,b).getOrElse(StringType)) + (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) CreateArray( children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) @@ -591,14 +588,9 @@ trait HiveTypeCoercion { // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => - val dt: Option[DataType] = Some(NullType) val types = es.map(_.dataType) - val rt = types.foldLeft(dt)((r, c) => r match { - case None => None - case Some(d) => findTightestCommonType(d, c) - }) - rt match { - case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + findTightestCommonType(types) match { + case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") } @@ -611,19 +603,15 @@ trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + // Skip nodes who has not been resolved yet, + // as this is an extra rule which should be applied at last. 
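// Illustrative example (not part of this patch): with two integer operands, `7 / 2` is
// rewritten to Divide(Cast(7, DoubleType), Cast(2, DoubleType)) and evaluates to 3.5;
// Divide expressions that are already DoubleType or DecimalType are left unchanged below.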
+ case e if !e.resolved => e // Decimal and Double remain the same - case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - - case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => - Divide(l, Cast(r, DecimalType.Unlimited)) - case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => - Divide(Cast(l, DecimalType.Unlimited), r) + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) + case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) } } @@ -634,25 +622,33 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => - logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") - val commonType = cw.valueTypes.reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = cw.branches.sliding(2, 2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case s => s - }.reduce(_ ++ _) - cw match { - case _: CaseWhen => - CaseWhen(transformedBranches) - case CaseKeyWhen(key, _) => - CaseKeyWhen(key, transformedBranches) - } + case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => + logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + val maybeCommonType = findTightestCommonType(c.valueTypes) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, value) if value.dataType != commonType => + Seq(when, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) + case other => other + }.reduce(_ ++ _) + c match { + case _: CaseWhen => CaseWhen(castedBranches) + case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) + } + }.getOrElse(c) + + case c: CaseKeyWhen if c.childrenResolved && !c.resolved => + val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType)) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, then) if when.dataType != commonType => + Seq(Cast(when, commonType), then) + case other => other + }.reduce(_ ++ _) + CaseKeyWhen(Cast(c.key, commonType), castedBranches) + }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala new file mode 100644 index 0000000000000..79c3528a522d3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +/** + * Represents the result of `Expression.checkInputDataTypes`. + * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true. + */ +trait TypeCheckResult { + def isFailure: Boolean = !isSuccess + def isSuccess: Boolean +} + +object TypeCheckResult { + + /** + * Represents the successful result of `Expression.checkInputDataTypes`. + */ + object TypeCheckSuccess extends TypeCheckResult { + def isSuccess: Boolean = true + } + + /** + * Represents the failing result of `Expression.checkInputDataTypes`, + * with a error message to show the reason of failure. + */ + case class TypeCheckFailure(message: String) extends TypeCheckResult { + def isSuccess: Boolean = false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 2999c2ef3efe1..bbb150c1e83c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -67,7 +67,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name" @@ -85,7 +85,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -107,7 +107,7 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] { override lazy val resolved = false // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] @@ -166,7 +166,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") override def toString: String = s"$child AS $names" @@ -200,7 +200,7 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child[$extraction]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 307a9ca9b0070..51821757967d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ @@ -61,7 +60,7 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- : Expression= UnaryMinus(expr) + def unary_- : Expression = UnaryMinus(expr) def unary_! : Predicate = Not(expr) def unary_~ : Expression = BitwiseNot(expr) @@ -141,7 +140,7 @@ package object dsl { // Note that if we make ExpressionConversions an object rather than a trait, we can // then make this a value class to avoid the small penalty of runtime instantiation. def $(args: Any*): analysis.UnresolvedAttribute = { - analysis.UnresolvedAttribute(sc.s(args :_*)) + analysis.UnresolvedAttribute(sc.s(args : _*)) } } @@ -234,133 +233,59 @@ package object dsl { implicit class DslAttribute(a: AttributeReference) { def notNull: AttributeReference = a.withNullability(false) def nullable: AttributeReference = a.withNullability(true) - - // Protobuf terminology - def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } - object expressions extends ExpressionConversions // scalastyle:ignore - abstract class LogicalPlanFunctions { - def logicalPlan: LogicalPlan - - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + object plans { // scalastyle:ignore + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) - def join( + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, condition: Option[Expression] = None): LogicalPlan = - Join(logicalPlan, otherPlan, joinType, condition) + Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): 
LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { + val aliasedExprs = aggregateExprs.map { + case ne: NamedExpression => ne + case e => Alias(e, e.toString)() + } + Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - Aggregate(groupingExprs, aliasedExprs, logicalPlan) - } - - def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) + def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) - def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) + def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names - def generate( + // TODO specify the output column names + def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = - InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = + InsertIntoTable( + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) - } - - object plans { // scalastyle:ignore - implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } } - - case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*): ScalaUdf = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) - } - } - - // scalastyle:off - /** functionToUdfBuilder 1-22 were generated by this script - - (1 to 22).map { x => - val argTypes = Seq.fill(x)("_").mkString(", ") - s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func)" - } - */ - - implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: 
TypeTag](func: Function5[_, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - // scalastyle:on } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0fd4f9b374ee0..d2a90a50c89f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -49,11 +49,4 @@ package object errors { case e: Exception => throw new TreeNodeException(tree, msg, e) } } - - /** - * Executes `f` which is expected to throw a - * [[catalyst.errors.TreeNodeException TreeNodeException]]. The first tree encountered in - * the stack of exceptions of type `TreeType` is returned. - */ - def getTree[TreeType <: TreeNode[_]](f: => Unit): TreeType = ??? 
// TODO: Implement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index c6217f07c452d..fcadf9595e768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -30,8 +31,6 @@ import org.apache.spark.sql.catalyst.trees case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends NamedExpression with trees.LeafNode[Expression] { - type EvaluatedType = Any - override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) @@ -43,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = i.isNullAt($ordinal); + ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? + ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + """ + } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d8cf2b2e32435..18102d1acb5b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -35,48 +36,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true + case (StringType, TimestampType) => true + case (DoubleType, TimestampType) => true + case (FloatType, TimestampType) => true + case (StringType, DateType) => true + case (_: NumericType, DateType) => true + case (BooleanType, DateType) => true + case (DateType, _: NumericType) => true + case (DateType, BooleanType) => true case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true + case (FloatType, _: DecimalType) => true case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null - case _ => false + case _ => false } 
private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to private[this] def resolve(from: DataType, to: DataType): Boolean = { (from, to) match { - case (from, to) if from == to => true + case (from, to) if from == to => true - case (NullType, _) => true + case (NullType, _) => true - case (_, StringType) => true + case (_, StringType) => true - case (StringType, BinaryType) => true + case (StringType, BinaryType) => true - case (StringType, BooleanType) => true - case (DateType, BooleanType) => true - case (TimestampType, BooleanType) => true - case (_: NumericType, BooleanType) => true + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true - case (StringType, TimestampType) => true - case (BooleanType, TimestampType) => true - case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => true + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true - case (_, DateType) => true + case (_, DateType) => true - case (StringType, _: NumericType) => true - case (BooleanType, _: NumericType) => true - case (DateType, _: NumericType) => true - case (TimestampType, _: NumericType) => true + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true case (_: NumericType, _: NumericType) => true case (ArrayType(from, fn), ArrayType(to, tn)) => @@ -105,8 +106,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def toString: String = s"CAST($child, $dataType)" - type EvaluatedType = Any - // [[func]] assumes the input is no longer null because eval already does the null check. 
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) @@ -162,7 +161,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0))) + buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0)) case LongType => buildCast[Long](_, l => new Timestamp(l)) case IntegerType => @@ -412,21 +411,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def cast(from: DataType, to: DataType): Any => Any = to match { case dt if dt == child.dataType => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) @@ -435,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + // TODO(cg): Add support for more data types. + (child.dataType, dataType) match { + + case (BinaryType, StringType) => + defineCodeGen (ctx, ev, c => + s"new ${ctx.stringType}().set($c)") + case (DateType, StringType) => + defineCodeGen(ctx, ev, c => + s"""new ${ctx.stringType}().set( + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case (TimestampType, StringType) => + super.genCode(ctx, ev) + case (_, StringType) => + defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + + // fallback for DecimalType, this must be before other numeric types + case (_, dt: DecimalType) => + super.genCode(ctx, ev) + + case (BooleanType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 
1 : 0)") + case (dt: DecimalType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c.isZero()") + case (dt: NumericType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c != 0") + + case (_: DecimalType, IntegerType) => + defineCodeGen(ctx, ev, c => s"($c).toInt()") + case (_: DecimalType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + case (_: NumericType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") + + case other => + super.genCode(ctx, ev) + } + } } object Cast { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c7ae9da7fce49..f2ed1f0929987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,17 +17,24 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ + +/** + * For Catalyst to work correctly, concrete implementations of [[Expression]]s must be case classes + * whose constructor arguments are all Expressions types. In addition, if we want to support more + * than one constructor, define those constructors explicitly as apply methods in the companion + * object. + * + * See [[Substring]] for an example. + */ abstract class Expression extends TreeNode[Expression] { self: Product => - /** The narrowest possible type that is produced when this expression is evaluated. */ - type EvaluatedType <: Any - /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. @@ -40,19 +47,66 @@ abstract class Expression extends TreeNode[Expression] { * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false + + /** + * Returns true when the current expression always return the same result for fixed input values. + */ + // TODO: Need to define explicit input values vs implicit input values. + def deterministic: Boolean = true + def nullable: Boolean + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ - def eval(input: Row = null): EvaluatedType + def eval(input: Row = null): Any + + /** + * Returns an [[GeneratedExpressionCode]], which contains Java source code that + * can be used to generate the result of evaluating the expression on an input row. + * + * @param ctx a [[CodeGenContext]] + * @return [[GeneratedExpressionCode]] + */ + def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + ve + } + + /** + * Returns Java source code that can be compiled to evaluate this expression. + * The default behavior is to call the eval method of the expression. Concrete expression + * implementations should override this to do actual code generation. 
+ * + * @param ctx a [[CodeGenContext]] + * @param ev an [[GeneratedExpressionCode]] with unique terms. + * @return Java source code + */ + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ctx.references += this + val objectTerm = ctx.freshName("obj") + s""" + /* expression: ${this} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it still contains any unresolved placeholders. Implementations of expressions - * should override this if the resolution of this type of expression involves more than just - * the resolution of its children. + * and input data types checking passed, and `false` if it still contains any unresolved + * placeholders or has data types mismatch. + * Implementations of expressions should override this if the resolution of this type of + * expression involves more than just the resolution of its children and type checking. */ - lazy val resolved: Boolean = childrenResolved + lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** * Returns the [[DataType]] of the result of evaluating this expression. It is @@ -89,18 +143,66 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + + /** + * Checks the input data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `childrenResolved == true` + * TODO: we should remove the default implementation and implement it for all + * expressions with proper error message. + */ + def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => - def symbol: String + def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") override def foldable: Boolean = left.foldable && right.foldable override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"($left $symbol $right)" + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + // TODO: Right now some timestamp tests fail if we enforce this... 
+ if (left.dataType != right.dataType) { + // log.warn(s"${left.dataType} != ${right.dataType}") + } + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(eval1.primitive, eval2.primitive) + + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + ${ev.primitive} = $resultCode; + } else { + ${ev.isNull} = true; + } + } + """ + } +} + +private[sql] object BinaryExpression { + def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -109,6 +211,32 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not null, use `f` to generate the expression. + * + * As an example, the following does a boolean inversion (i.e. NOT). + * {{{ + * defineCodeGen(ctx, ev, c => s"!($c)") + * }}} + * + * @param f function that accepts a variable name and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: String => String): String = { + val eval = child.gen(ctx) + // reuse the previous isNull + ev.isNull = eval.isNull + eval.code + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${f(eval.primitive)}; + } + """ + } } // TODO Semantically we probably not need GroupExpression @@ -117,8 +245,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio // not like a real expressions. case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => - type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = false override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -129,7 +256,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { * so that the proper type conversions can be performed in the analyzer. */ trait ExpectsInputTypes { + self: Expression => def expectedChildTypes: Seq[DataType] + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underlying `Cast`s.
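As an illustration of the null-propagating template these defineCodeGen helpers emit, here is a minimal standalone sketch (illustrative variable names only, not the exact generated output): evaluate the left child, evaluate the right child only when the left is non-null, and mark the result null if either side is null.

object DefineCodeGenSketch {
  // Mirrors the shape produced by the binary defineCodeGen above.
  def nullSafeBinary(leftCode: String, leftIsNull: String, leftVal: String,
                     rightCode: String, rightIsNull: String, rightVal: String,
                     javaType: String, defaultValue: String,
                     isNull: String, result: String)(f: (String, String) => String): String =
    s"""
       |$leftCode
       |boolean $isNull = $leftIsNull;
       |$javaType $result = $defaultValue;
       |if (!$isNull) {
       |  $rightCode
       |  if (!$rightIsNull) {
       |    $result = ${f(leftVal, rightVal)};
       |  } else {
       |    $isNull = true;
       |  }
       |}
       |""".stripMargin

  def main(args: Array[String]): Unit =
    // e.g. an integer addition over columns 0 and 1 of the input row `i`
    println(nullSafeBinary(
      "int v0 = i.getInt(0);", "i.isNullAt(0)", "v0",
      "int v1 = i.getInt(1);", "i.isNullAt(1)", "v1",
      "int", "-1", "isNull0", "result0")((a, b) => s"$a + $b"))
}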
+ TypeCheckResult.TypeCheckSuccess + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index e05926cbfe74b..a1e0819e8a433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -47,7 +47,7 @@ object ExtractValue { case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => val ordinal = findField(fields, fieldName.toString, resolver) GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) - case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) case (_: MapType, _) => GetMapValue(child, extraction) @@ -92,8 +92,6 @@ object ExtractValue { trait ExtractValue extends UnaryExpression { self: Product => - - type EvaluatedType = Any } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fe2873e0be34d..5b45347872cca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types.DataType case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { - type EvaluatedType = Any - override def nullable: Boolean = true override def toString: String = s"scalaUDF(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 83074eb1e6310..99340a14c9ecc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -29,14 +29,14 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends Expression +case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index f3830c6d3bcf2..0266084a6d174 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -37,7 +37,7 @@ abstract class AggregateExpression extends Expression { * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are * replaced with a physical aggregate operator at runtime. */ - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -74,8 +74,6 @@ abstract class AggregateFunction extends AggregateExpression with Serializable with trees.LeafNode[Expression] { self: Product => - override type EvaluatedType = Any - /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression @@ -113,7 +111,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def update(input: Row): Unit = { if (currentMin.value == null) { currentMin.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMin.value = expr.eval(input) } } @@ -144,7 +142,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def update(input: Row): Unit = { if (currentMax.value == null) { currentMax.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMax.value = expr.eval(input) } } @@ -396,13 +394,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ * Combining PartitionLevel InputData * <-- null * Zero <-- Zero <-- null - * + * * <-- null <-- no data - * null <-- null <-- no data + * null <-- null <-- no data */ case class CombineSum(child: Expression) extends AggregateExpression { def this() = this(null) - + override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -618,7 +616,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) override def update(input: Row): Unit = { @@ -636,7 +634,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - + def this() = this(null, null) // Required for serialization. 
private val calcType = @@ -651,12 +649,12 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - + override def update(input: Row): Unit = { val result = expr.eval(input) - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present if(result != null) { sum.update(addFunction, input) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c7a37ad966df6..124274c94203c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,76 +17,113 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -case class UnaryMinus(child: Expression) extends UnaryExpression { - type EvaluatedType = Any +abstract class UnaryArithmetic extends UnaryExpression { + self: Product => - override def dataType: DataType = child.dataType override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable - override def toString: String = s"-$child" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override def dataType: DataType = child.dataType override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - numeric.negate(evalE) + evalInternal(evalE) } } + + protected def evalInternal(evalE: Any): Any = + sys.error(s"UnaryArithmetics must override either eval or evalInternal") } -case class Sqrt(child: Expression) extends UnaryExpression { - type EvaluatedType = Any +case class UnaryMinus(child: Expression) extends UnaryArithmetic { + override def toString: String = s"-$child" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "operator -") + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") + } + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) +} + +case class Sqrt(child: Expression) extends UnaryArithmetic { override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"SQRT($child)" - lazy val numeric = child.dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support non-negative numeric operations") + override def checkInputDataTypes(): TypeCheckResult = + 
TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") + + private lazy val numeric = TypeUtils.getNumeric(child.dataType) + + protected override def evalInternal(evalE: Any) = { + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.sqrt(value) } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} < 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); + } + } + """ } } +/** + * A function that get the absolute value of the numeric value. + */ +case class Abs(child: Expression) extends UnaryArithmetic { + override def toString: String = s"Abs($child)" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function abs") + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.abs(evalE) +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => - type EvaluatedType = Any + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = "" - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. 
Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) } - left.dataType } + protected def checkTypesInternal(t: DataType): TypeCheckResult + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -101,90 +138,87 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = - sys.error(s"BinaryExpressions must either override eval or evalInternal") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + defineCodeGen(ctx, ev, (eval1, eval2) => + s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } + + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryArithmetics must override either eval or evalInternal") +} + +private[sql] object BinaryArithmetic { + def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.plus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" + override def decimalMethod: String = "$minus" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.minus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2) } case class Multiply(left: 
Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" + override def decimalMethod: String = "$times" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.times(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" + override def decimalMethod: String = "$divide" override def nullable: Boolean = true - lazy val div: (Any, Any) => Any = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot - case other => sys.error(s"Type $other does not support numeric operations") } - + override def eval(input: Row): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { @@ -198,17 +232,51 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } } + + /** + * Special case handling due to division by 0 => null. 
+ */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" + } else { + s"${eval2.primitive} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + } + """ + } } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" + override def decimalMethod: String = "reminder" override def nullable: Boolean = true - lazy val integral = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -224,129 +292,43 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } } -} -/** - * A function that calculates bitwise and(&) of two numbers. - */ -case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "&" - - lazy val and: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise & operation on $other") - } - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = and(evalE1, evalE2) -} - -/** - * A function that calculates bitwise or(|) of two numbers. - */ -case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "|" - - lazy val or: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise | operation on $other") - } - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = or(evalE1, evalE2) -} - -/** - * A function that calculates bitwise xor(^) of two numbers. 
- */ -case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "^" - - lazy val xor: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise ^ operation on $other") - } - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = xor(evalE1, evalE2) -} - -/** - * A function that calculates bitwise not(~) of a number. - */ -case class BitwiseNot(child: Expression) extends UnaryExpression { - type EvaluatedType = Any - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"~$child" - - lazy val not: (Any) => Any = dataType match { - case ByteType => - ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] - case ShortType => - ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] - case IntegerType => - ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] - case LongType => - ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") - } - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null + /** + * Special case handling for x % 0 ==> null. + */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" + } else { + s"${eval2.primitive} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" } else { - not(evalE) + s"$symbol" } + eval1.code + eval2.code + + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + } + """ } } -case class MaxOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any - - override def foldable: Boolean = left.foldable && right.foldable - +case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. 
Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function maxOf") - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -364,34 +346,43 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (ctx.isNativeType(left.dataType)) { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if (${eval1.primitive} > ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ + } else { + super.genCode(ctx, ev) + } + } override def toString: String = s"MaxOf($left, $right)" } -case class MinOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any - - override def foldable: Boolean = left.foldable && right.foldable - +case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function minOf") - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -409,31 +400,35 @@ case class MinOf(left: Expression, right: Expression) extends Expression { } } - override def toString: String = s"MinOf($left, $right)" -} - -/** - * A function that get the absolute value of the numeric value. 
- */ -case class Abs(child: Expression) extends UnaryExpression { - type EvaluatedType = Any - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"Abs($child)" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (ctx.isNativeType(left.dataType)) { + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if (${eval1.primitive} < ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ } else { - numeric.abs(evalE) + super.genCode(ctx, ev) } } + + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala new file mode 100644 index 0000000000000..9002dda7bf4d0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + + +/** + * A function that calculates bitwise and(&) of two numbers. + * + * Code generation inherited from BinaryArithmetic. 
+ */ +case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "&" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val and: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) +} + +/** + * A function that calculates bitwise or(|) of two numbers. + * + * Code generation inherited from BinaryArithmetic. + */ +case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "|" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val or: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) +} + +/** + * A function that calculates bitwise xor of two numbers. + * + * Code generation inherited from BinaryArithmetic. + */ +case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "^" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val xor: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) +} + +/** + * A function that calculates bitwise not(~) of a number. 
+ */ +case class BitwiseNot(child: Expression) extends UnaryArithmetic { + override def toString: String = s"~$child" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + + private lazy val not: (Any) => Any = dataType match { + case ByteType => + ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] + case ShortType => + ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] + case IntegerType => + ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] + case LongType => + ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + } + + protected override def evalInternal(evalE: Any) = not(evalE) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ecb4c4b68f904..e95682f952a7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.google.common.cache.{CacheLoader, CacheBuilder} - +import scala.collection.mutable import scala.language.existentials +import com.google.common.cache.{CacheBuilder, CacheLoader} +import org.codehaus.janino.ClassBodyEvaluator + import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -31,28 +32,166 @@ class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] /** - * A base class for generators of byte code to perform expression evaluation. Includes a set of - * helpers for referring to Catalyst types and building trees that perform evaluation of individual - * expressions. + * Java source for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param isNull A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitive A term for a possible primitive value of the result of the evaluation. Not + * valid if `isNull` is set to `true`. */ -abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ +case class GeneratedExpressionCode(var code: String, var isNull: String, var primitive: String) - import scala.tools.reflect.ToolBox - - protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() +/** + * A context for codegen, used to keep track of the expressions that are not supported + * by codegen so that they can be evaluated directly. Each unsupported expression is appended at the + * end of `references`; its position is recorded in the generated code and used to access and evaluate it.
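A minimal standalone sketch of that bookkeeping, assuming nothing beyond the standard library (names are illustrative, not Spark's CodeGenContext): unique variable names come from a counter, and expressions that must fall back to interpreted evaluation are stored in a buffer whose index is embedded in the generated code.

object CodeGenContextSketch {
  import java.util.concurrent.atomic.AtomicInteger
  import scala.collection.mutable.ArrayBuffer

  // A stripped-down stand-in for the bookkeeping described above.
  class MiniContext {
    private val curId = new AtomicInteger()
    val references = ArrayBuffer.empty[AnyRef]
    def freshName(prefix: String): String = s"$prefix${curId.getAndIncrement}"
  }

  def main(args: Array[String]): Unit = {
    val ctx = new MiniContext
    // Each call yields a distinct Java identifier, e.g. isNull0, primitive1, isNull2.
    println(Seq("isNull", "primitive", "isNull").map(ctx.freshName))
    ctx.references += "someInterpretedExpr"
    println(s"fallback slot: ${ctx.references.size - 1}")
  }
}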
+ */ +class CodeGenContext { - protected val rowType = typeOf[Row] - protected val mutableRowType = typeOf[MutableRow] - protected val genericRowType = typeOf[GenericRow] - protected val genericMutableRowType = typeOf[GenericMutableRow] + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - protected val projectionType = typeOf[Projection] - protected val mutableProjectionType = typeOf[MutableProjection] + val stringType: String = classOf[UTF8String].getName + val decimalType: String = classOf[Decimal].getName private val curId = new java.util.concurrent.atomic.AtomicInteger() - private val javaSeparator = "$" + + /** + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) + */ + def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" + } + + /** + * Return the code to access a column for given DataType + */ + def getColumn(dataType: DataType, ordinal: Int): String = { + if (isNativeType(dataType)) { + s"i.${accessorForType(dataType)}($ordinal)" + } else { + s"(${boxedType(dataType)})i.apply($ordinal)" + } + } + + /** + * Return the code to update a column in Row for given DataType + */ + def setColumn(dataType: DataType, ordinal: Int, value: String): String = { + if (isNativeType(dataType)) { + s"${mutatorForType(dataType)}($ordinal, $value)" + } else { + s"update($ordinal, $value)" + } + } + + /** + * Return the name of accessor in Row for a DataType + */ + def accessorForType(dt: DataType): String = dt match { + case IntegerType => "getInt" + case other => s"get${boxedType(dt)}" + } + + /** + * Return the name of mutator in Row for a DataType + */ + def mutatorForType(dt: DataType): String = dt match { + case IntegerType => "setInt" + case other => s"set${boxedType(dt)}" + } + + /** + * Return the Java type for a DataType + */ + def javaType(dt: DataType): String = dt match { + case IntegerType => "int" + case LongType => "long" + case ShortType => "short" + case ByteType => "byte" + case DoubleType => "double" + case FloatType => "float" + case BooleanType => "boolean" + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "int" + case TimestampType => "java.sql.Timestamp" + case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName + case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case _ => "Object" + } + + /** + * Return the boxed type in Java + */ + def boxedType(dt: DataType): String = dt match { + case IntegerType => "Integer" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case DateType => "Integer" + case _ => javaType(dt) + } + + /** + * Return the representation of default value for given DataType + */ + def defaultValue(dt: DataType): String = dt match { + case BooleanType => "false" + case FloatType => "-1.0f" + case ShortType => "(short)-1" + case LongType => "-1L" + case ByteType => "(byte)-1" + case DoubleType => "-1.0" + case IntegerType => "-1" + case DateType => "-1" + case _ => "null" + } + + /** + * Returns a function to generate equal expression in Java + */ + def equalFunc(dataType: DataType): ((String, String) 
=> String) = dataType match { + case BinaryType => { case (eval1, eval2) => + s"java.util.Arrays.equals($eval1, $eval2)" } + case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => + { case (eval1, eval2) => s"$eval1 == $eval2" } + case other => + { case (eval1, eval2) => s"$eval1.equals($eval2)" } + } + + /** + * List of data types that have special accessors and setters in [[Row]]. + */ + val nativeTypes = + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) + + /** + * Returns true if the data type has a special accessor and setter in [[Row]]. + */ + def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) +} + +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + + protected val exprType: String = classOf[Expression].getName + protected val mutableRowType: String = classOf[MutableRow].getName + protected val genericMutableRowType: String = classOf[GenericMutableRow].getName /** * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. @@ -74,6 +213,26 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** + * Compile the Java source code into a Java class, using Janino. + * + * It will track the time used to compile + */ + protected def compile(code: String): Class[_] = { + val startTime = System.nanoTime() + val clazz = try { + new ClassBodyEvaluator(code).getClazz() + } catch { + case e: Exception => + logError(s"failed to compile:\n $code", e) + throw e + } + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") + clazz + } + /** * A cache of generated classes. * @@ -87,7 +246,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .maximumSize(1000) .build( new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = globalLock.synchronized { + override def load(in: InType): OutType = { val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() @@ -105,590 +264,10 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen */ - protected def freshName(prefix: String): TermName = { - newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") - } - - /** - * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. - * - * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated - * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. 
Not - * valid if `nullTerm` is set to `true`. - * @param objectTerm A possibly boxed version of the result of evaluating this expression. - */ - protected case class EvaluatedExpression( - code: Seq[Tree], - nullTerm: TermName, - primitiveTerm: TermName, - objectTerm: TermName) - - /** - * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that - * can be used to determine the result of evaluating the expression on an input row. - */ - def expressionEvaluator(e: Expression): EvaluatedExpression = { - val primitiveTerm = freshName("primitiveTerm") - val nullTerm = freshName("nullTerm") - val objectTerm = freshName("objectTerm") - - implicit class Evaluate1(e: Expression) { - def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${f(eval.primitiveTerm)} - """.children - } - } - - implicit class Evaluate2(expressions: (Expression, Expression)) { - - /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f a function from two primitive term names to a tree that evaluates them. - */ - def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = - evaluateAs(expressions._1.dataType)(f) - - def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (expressions._1.dataType != expressions._2.dataType) { - log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") - } - - val eval1 = expressionEvaluator(expressions._1) - val eval2 = expressionEvaluator(expressions._2) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - - eval1.code ++ eval2.code ++ - q""" - val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} - val $primitiveTerm: ${termForType(resultType)} = - if($nullTerm) { - ${defaultPrimitive(resultType)} - } else { - $resultCode.asInstanceOf[${termForType(resultType)}] - } - """.children : Seq[Tree] - } - } - - val inputTuple = newTermName(s"i") - - // TODO: Skip generation of null handling code when expression are not nullable. 
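A minimal sketch (not part of the patch) of how the new CodeGenContext helpers above compose into a Java fragment for a null-safe column read. The wrapper object and method name are illustrative only; `i` is assumed to be the input Row variable used throughout the generated classes in this patch.

    import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
    import org.apache.spark.sql.types.DataType

    object ColumnReadSketch {
      // Emits Java source that reads column `ordinal` of row `i` into a fresh local,
      // substituting the type's default value when the slot is null.
      def nullSafeRead(ctx: CodeGenContext, dataType: DataType, ordinal: Int): String = {
        val isNull = ctx.freshName("isNull")   // e.g. "isNull0"
        val value = ctx.freshName("value")     // e.g. "value1"
        s"""
          boolean $isNull = i.isNullAt($ordinal);
          ${ctx.javaType(dataType)} $value = $isNull
            ? ${ctx.defaultValue(dataType)}
            : ${ctx.getColumn(dataType, ordinal)};
        """
      }
    }

For IntegerType at ordinal 0 this yields roughly `boolean isNull0 = i.isNullAt(0); int value1 = isNull0 ? -1 : i.getInt(0);`.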
- val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { - case b @ BoundReference(ordinal, dataType, nullable) => - val nullValue = q"$inputTuple.isNullAt($ordinal)" - q""" - val $nullTerm: Boolean = $nullValue - val $primitiveTerm: ${termForType(dataType)} = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${getColumn(inputTuple, dataType, ordinal)} - """.children - - case expressions.Literal(null, dataType) => - q""" - val $nullTerm = true - val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] - """.children - - case expressions.Literal(value: Boolean, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: UTF8String, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = - org.apache.spark.sql.types.UTF8String(${value.getBytes}) - """.children - - case expressions.Literal(value: Int, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: Long, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case Cast(e @ BinaryType(), StringType) => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) - """.children - - case Cast(child @ DateType(), StringType) => - child.castOrNull(c => - q"""org.apache.spark.sql.types.UTF8String( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) - - case Cast(child @ NumericType(), IntegerType) => - child.castOrNull(c => q"$c.toInt", IntegerType) - - case Cast(child @ NumericType(), LongType) => - child.castOrNull(c => q"$c.toLong", LongType) - - case Cast(child @ NumericType(), DoubleType) => - child.castOrNull(c => q"$c.toDouble", DoubleType) - - case Cast(child @ NumericType(), FloatType) => - child.castOrNull(c => q"$c.toFloat", FloatType) - - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. - case Cast(e, StringType) if e.dataType != TimestampType => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) - """.children - - case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => - (e1, e2).evaluateAs (BooleanType) { - case (eval1, eval2) => - q""" - java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], - $eval2.asInstanceOf[Array[Byte]]) - """ - } - - case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } - - /* TODO: Fix null semantics. 
- case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => - val eval = expressionEvaluator(e1) - - val checks = list.map { - case expressions.Literal(v: String, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - case expressions.Literal(v: Int, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - } - - val funcName = newTermName(s"isIn${curId.getAndIncrement()}") - - q""" - def $funcName: Boolean = { - ..${eval.code} - if(${eval.nullTerm}) return false - ..$checks - return false - } - val $nullTerm = false - val $primitiveTerm = $funcName - """.children - */ - - case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } - case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } - case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } - case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } - - case And(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) { - } else { - ..${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true - } else { - $nullTerm = true - } - } - """.children - - case Or(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true - } else { - ..${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false - } else { - $nullTerm = true - } - } - """.children - - case Not(child) => - // Uh, bad function name... 
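The removed And/Or rules above encode SQL's three-valued logic, which the new Java codegen has to preserve. A standalone sketch of the same truth table in plain Scala, with None standing in for SQL NULL:

    // Three-valued AND, as implemented by the quasiquote rule above:
    // a false operand wins regardless of nulls; otherwise any null makes the result null.
    def sqlAnd(a: Option[Boolean], b: Option[Boolean]): Option[Boolean] = (a, b) match {
      case (Some(false), _)         => Some(false)
      case (_, Some(false))         => Some(false)
      case (Some(true), Some(true)) => Some(true)
      case _                        => None
    }

    sqlAnd(Some(false), None)  // Some(false)
    sqlAnd(Some(true), None)   // None
    sqlAnd(None, Some(true))   // None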
- child.castOrNull(c => q"!$c", BooleanType) - - case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } - case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } - case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } - case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm} - } - """.children - - case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $nullTerm = false - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm} - } - """.children - - case IsNotNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} - """.children - - case IsNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} - """.children - - case c @ Coalesce(children) => - q""" - var $nullTerm = true - var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} - """.children ++ - children.map { c => - val eval = expressionEvaluator(c) - q""" - if($nullTerm) { - ..${eval.code} - if(!${eval.nullTerm}) { - $nullTerm = false - $primitiveTerm = ${eval.primitiveTerm} - } - } - """ - } - - case i @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition) - val trueEval = expressionEvaluator(trueValue) - val falseEval = expressionEvaluator(falseValue) - - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} - ..${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ..${trueEval.code} - $nullTerm = ${trueEval.nullTerm} - $primitiveTerm = ${trueEval.primitiveTerm} - } else { - ..${falseEval.code} - $nullTerm = ${falseEval.nullTerm} - $primitiveTerm = ${falseEval.primitiveTerm} - } - """.children - - case NewSet(elementType) => - q""" - val $nullTerm = false - val $primitiveTerm = new ${hashSetForType(elementType)}() - """.children - - case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item) - val setEval = expressionEvaluator(set) - - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - - itemEval.code ++ setEval.code ++ - q""" - if (!${itemEval.nullTerm}) { - ${setEval.primitiveTerm} - .asInstanceOf[${hashSetForType(elementType)}] - .add(${itemEval.primitiveTerm}) - } - - val $nullTerm = false - val $primitiveTerm = ${setEval.primitiveTerm} - """.children - - case CombineSets(left, right) => - val leftEval = expressionEvaluator(left) - val rightEval = expressionEvaluator(right) - - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - - leftEval.code ++ rightEval.code ++ - q""" - val $nullTerm = false - var $primitiveTerm: 
${hashSetForType(elementType)} = null - - { - val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val iterator = rightSet.iterator - while (iterator.hasNext) { - leftSet.add(iterator.next()) - } - $primitiveTerm = leftSet - } - """.children - - case MaxOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} - } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} - } else { - $primitiveTerm = ${eval2.primitiveTerm} - } - } - """.children - - case MinOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} - } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} - } else { - $primitiveTerm = ${eval2.primitiveTerm} - } - } - """.children - - case UnscaledValue(child) => - val childEval = expressionEvaluator(child) - - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: Long = if (!$nullTerm) { - ${childEval.primitiveTerm}.toUnscaledLong - } else { - ${defaultPrimitive(LongType)} - } - """.children - - case MakeDecimal(child, precision, scale) => - val childEval = expressionEvaluator(child) - - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: org.apache.spark.sql.types.Decimal = - ${defaultPrimitive(DecimalType())} - - if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal() - $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) - $nullTerm = $primitiveTerm == null - } - """.children - } - - // If there was no match in the partial function above, we fall back on calling the interpreted - // expression evaluator. - val code: Seq[Tree] = - primitiveEvaluation.lift.apply(e).getOrElse { - log.debug(s"No rules to generate $e") - val tree = reify { e } - q""" - val $objectTerm = $tree.eval(i) - val $nullTerm = $objectTerm == null - val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] - """.children - } - - // Only inject debugging code if debugging is turned on. 
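The removed code above falls back to the interpreted evaluator whenever no generation rule matches; under the new scheme the same fallback can be expressed by stashing the expression in `ctx.references` and indexing into the `expressions` array that every generated class receives in its constructor. A sketch under that assumption (the helper name is hypothetical; `i`, `expressions`, and the `GeneratedExpressionCode` fields follow the hunks in this patch):

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}

    object InterpretedFallbackSketch {
      def fallback(ctx: CodeGenContext, e: Expression, ev: GeneratedExpressionCode): String = {
        ctx.references += e                  // passed to the generated constructor as expressions[idx]
        val idx = ctx.references.size - 1
        val obj = ctx.freshName("obj")
        s"""
          Object $obj = expressions[$idx].eval(i);
          boolean ${ev.isNull} = $obj == null;
          ${ctx.javaType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
          if (!${ev.isNull}) {
            ${ev.primitive} = (${ctx.boxedType(e.dataType)}) $obj;
          }
        """
      }
    }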
- val debugCode = - if (debugLogging) { - val localLogger = log - val localLoggerTree = reify { localLogger } - q""" - $localLoggerTree.debug( - ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) - """ :: Nil - } else { - Nil - } - - EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) - } - - protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { - dataType match { - case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" - case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" - } + def newCodeGenContext(): CodeGenContext = { + new CodeGenContext } - - protected def setColumn( - destinationRow: TermName, - dataType: DataType, - ordinal: Int, - value: TermName) = { - dataType match { - case StringType => q"$destinationRow.update($ordinal, $value)" - case dt: DataType if isNativeType(dt) => - q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => q"$destinationRow.update($ordinal, $value)" - } - } - - protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") - protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") - - protected def hashSetForType(dt: DataType) = dt match { - case IntegerType => typeOf[IntegerHashSet] - case LongType => typeOf[LongHashSet] - case unsupportedType => - sys.error(s"Code generation not support for hashset of type $unsupportedType") - } - - protected def primitiveForType(dt: DataType) = dt match { - case IntegerType => "Int" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case StringType => "org.apache.spark.sql.types.UTF8String" - } - - protected def defaultPrimitive(dt: DataType) = dt match { - case BooleanType => ru.Literal(Constant(false)) - case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" - case ShortType => ru.Literal(Constant(-1.toShort)) - case LongType => ru.Literal(Constant(-1L)) - case ByteType => ru.Literal(Constant(-1.toByte)) - case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" - case IntegerType => ru.Literal(Constant(-1)) - case DateType => ru.Literal(Constant(-1)) - case _ => ru.Literal(Constant(null)) - } - - protected def termForType(dt: DataType) = dt match { - case n: AtomicType => n.tag - case _ => typeTag[Any] - } - - /** - * List of data types that have special accessors and setters in [[Row]]. - */ - protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) - - /** - * Returns true if the data type has a special accessor and setter in [[Row]]. 
- */ - protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 840260703ab74..e5ee2accd8a84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +// MutableProjection is not accessible in Java +abstract class BaseMutableProjection extends MutableProjection {} + /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new * input [[Row]] for a fixed set of [[Expression Expressions]]. */ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -36,41 +35,61 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu in.map(BindReferences.bindReference(_, inputSchema)) protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { - val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => - val evaluationCode = expressionEvaluator(e) - - evaluationCode.code :+ - q""" - if(${evaluationCode.nullTerm}) - mutableRow.setNullAt($i) - else - ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} - """ - } + val ctx = newCodeGenContext() + val projectionCode = expressions.zipWithIndex.map { case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if(${evaluationCode.isNull}) + mutableRow.setNullAt($i); + else + mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; + """ + }.mkString("\n") + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); + } + + class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - val code = - q""" - () => { new $mutableProjectionType { + private $exprType[] expressions = null; + private $mutableRowType mutableRow = null; - private[this] var $mutableRowName: $mutableRowType = - new $genericMutableRowType(${expressions.size}) + public SpecificProjection($exprType[] expr) { + expressions = expr; + mutableRow = new $genericMutableRowType(${expressions.size}); + } - def target(row: $mutableRowType): $mutableProjectionType = { - $mutableRowName = row - this - } + public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + mutableRow = row; + return this; + } - /* Provide immutable access to the last projected row. */ - def currentValue: $rowType = mutableRow + /* Provide immutable access to the last projected row. 
*/ + public Row currentValue() { + return mutableRow; + } - def apply(i: $rowType): $rowType = { - ..$projectionCode - mutableRow - } - } } - """ + public Object apply(Object _i) { + Row i = (Row) _i; + $projectionCode - log.debug(s"code for ${expressions.mkString(",")}:\n$code") - toolBox.eval(code).asInstanceOf[() => MutableProjection] + return mutableRow; + } + } + """ + + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + () => { + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection] + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index b129c0d898bb7..36e155d164a40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -18,18 +18,29 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging +import org.apache.spark.annotation.Private +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} +import org.apache.spark.sql.types.{BinaryType, NumericType} + +/** + * Inherits some default implementation for Java from `Ordering[Row]` + */ +@Private +class BaseOrdering extends Ordering[Row] { + def compare(a: Row, b: Row): Int = { + throw new UnsupportedOperationException + } +} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. 
*/ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = @@ -38,73 +49,90 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { val a = newTermName("a") val b = newTermName("b") - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child) - val evalB = expressionEvaluator(order.child) + val ctx = newCodeGenContext() + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = order.child.gen(ctx) + val evalB = order.child.gen(ctx) + val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => - q""" - val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} - val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} - var i = 0 - while (i < x.length && i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - i = i+1 - } - return x.length - y.length - """ + s""" + { + byte[] x = ${if (asc) evalA.primitive else evalB.primitive}; + byte[] y = ${if (!asc) evalB.primitive else evalA.primitive}; + int j = 0; + while (j < x.length && j < y.length) { + if (x[j] != y[j]) return x[j] - y[j]; + j = j + 1; + } + int d = x.length - y.length; + if (d != 0) { + return d; + } + }""" case _: NumericType => - q""" - val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} - if(comp != 0) { - return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} - } - """ - case StringType => - if (order.direction == Ascending) { - q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + s""" + if (${evalA.primitive} != ${evalB.primitive}) { + if (${evalA.primitive} > ${evalB.primitive}) { + return ${if (asc) "1" else "-1"}; + } else { + return ${if (asc) "-1" else "1"}; + } + }""" + case _ => + s""" + int comp = ${evalA.primitive}.compare(${evalB.primitive}); + if (comp != 0) { + return ${if (asc) "comp" else "-comp"}; + }""" + } + + s""" + i = $a; + ${evalA.code} + i = $b; + ${evalB.code} + if (${evalA.isNull} && ${evalB.isNull}) { + // Nothing + } else if (${evalA.isNull}) { + return ${if (order.direction == Ascending) "-1" else "1"}; + } else if (${evalB.isNull}) { + return ${if (order.direction == Ascending) "1" else "-1"}; } else { - q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + $compare } + """ + }.mkString("\n") + + val code = s""" + import org.apache.spark.sql.Row; + + public SpecificOrdering generate($exprType[] expr) { + return new SpecificOrdering(expr); } - q""" - i = $a - ..${evalA.code} - i = $b - ..${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { - // Nothing - } else if (${evalA.nullTerm}) { - return ${if (order.direction == Ascending) q"-1" else q"1"} - } else if (${evalB.nullTerm}) { - return ${if (order.direction == Ascending) q"1" else q"-1"} - } else { - $compare + class SpecificOrdering extends ${typeOf[BaseOrdering]} { + + private $exprType[] expressions = null; + + public SpecificOrdering($exprType[] expr) { + expressions = expr; 
} - """ - } - val q"class $orderingName extends $orderingType { ..$body }" = reify { - class SpecificOrdering extends Ordering[Row] { - val o = ordering - } - }.tree.children.head - - val code = q""" - class $orderingName extends $orderingType { - ..$body - def compare(a: $rowType, b: $rowType): Int = { - var i: $rowType = null // Holds current row being evaluated. - ..$comparisons - return 0 + @Override + public int compare(Row a, Row b) { + Row i = null; // Holds current row being evaluated. + $comparisons + return 0; } - } - new $orderingName() - """ + }""" + logDebug(s"Generated Ordering: $code") - toolBox.eval(code).asInstanceOf[Ordering[Row]] + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 40e163024360e..4a547b5ce9543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +/** + * Interface for generated predicate + */ +abstract class Predicate { + def eval(r: Row): Boolean +} + /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. */ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) @@ -32,17 +37,34 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { BindReferences.bindReference(in, inputSchema) protected def create(predicate: Expression): ((Row) => Boolean) = { - val cEval = expressionEvaluator(predicate) + val ctx = newCodeGenContext() + val eval = predicate.gen(ctx) + val code = s""" + import org.apache.spark.sql.Row; - val code = - q""" - (i: $rowType) => { - ..${cEval.code} - if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + public SpecificPredicate generate($exprType[] expr) { + return new SpecificPredicate(expr); + } + + class SpecificPredicate extends ${classOf[Predicate].getName} { + private final $exprType[] expressions; + public SpecificPredicate($exprType[] expr) { + expressions = expr; + } + + @Override + public boolean eval(Row i) { + ${eval.code} + return !${eval.isNull} && ${eval.primitive}; } - """ + }""" + + logDebug(s"Generated predicate '$predicate':\n$code") - log.debug(s"Generated predicate '$predicate':\n$code") - toolBox.eval(code).asInstanceOf[Row => Boolean] + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate] + (r: Row) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 584f938445c8c..7caf4aaab88bb 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProject extends Projection {} /** * Generates bytecode that produces a new [[Row]] object based on a fixed set of input @@ -27,7 +32,6 @@ import org.apache.spark.sql.types._ * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -38,201 +42,183 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { // Make Mutablility optional... protected def create(expressions: Seq[Expression]): Projection = { - val tupleLength = ru.Literal(Constant(expressions.length)) - val lengthDef = q"final val length = $tupleLength" - - /* TODO: Configurable... - val nullFunctions = - q""" - private final val nullSet = new org.apache.spark.util.collection.BitSet(length) - final def setNullAt(i: Int) = nullSet.set(i) - final def isNullAt(i: Int) = nullSet.get(i) - """ - */ - - val nullFunctions = - q""" - private[this] var nullBits = new Array[Boolean](${expressions.size}) - override def setNullAt(i: Int) = { nullBits(i) = true } - override def isNullAt(i: Int) = nullBits(i) - """.children - - val tupleElements = expressions.zipWithIndex.flatMap { + val ctx = newCodeGenContext() + val columns = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") - val evaluatedExpression = expressionEvaluator(e) - val iLit = ru.Literal(Constant(i)) + s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" + }.mkString("\n ") - q""" - var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + val initColumns = expressions.zipWithIndex.map { + case (e, i) => + val eval = e.gen(ctx) + s""" { - ..${evaluatedExpression.code} - if(${evaluatedExpression.nullTerm}) - setNullAt($iLit) - else { - nullBits($iLit) = false - $elementName = ${evaluatedExpression.primitiveTerm} + // column$i + ${eval.code} + nullBits[$i] = ${eval.isNull}; + if (!${eval.isNull}) { + c$i = ${eval.primitive}; } } - """.children : Seq[Tree] - } - - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" - val applyFunction = { - val cases = (0 until expressions.size).map { i => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) - - q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" - } - q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }" - } - - val updateFunction = { - val cases = expressions.zipWithIndex.map {case (e, i) => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) - - q""" - if(i == $ordinal) { - if(value == null) { - setNullAt(i) - } else { - nullBits(i) = false - $elementName = value.asInstanceOf[${termForType(e.dataType)}] - } - return - }""" - } - q"override def update(i: Int, value: Any): Unit = { 
..$cases; $accessorFailure }" - } - - val specificAccessorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // getString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) return $elementName" :: Nil - case _ => Nil - } - dataType match { - // Row() need this interface to compile - case StringType => - q""" - override def getString(i: Int): String = { - $accessorFailure - }""" - case other => - q""" - override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" - } - } - - val specificMutatorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // setString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil - case _ => Nil + """ + }.mkString("\n") + + val getCases = (0 until expressions.size).map { i => + s"case $i: return c$i;" + }.mkString("\n ") + + val updateCases = expressions.zipWithIndex.map { case (e, i) => + s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" + }.mkString("\n ") + + val specificAccessorFunctions = ctx.nativeTypes.map { dataType => + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: return c$i;" + case _ => "" + }.mkString("\n ") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) { + if (isNullAt(i)) { + return ${ctx.defaultValue(dataType)}; + } + switch (i) { + $cases + } + return ${ctx.defaultValue(dataType)}; + }""" + } else { + "" } - dataType match { - case StringType => - // MutableRow() need this interface to compile - q""" - override def setString(i: Int, value: String) { - $accessorFailure - }""" - case other => - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { - ..$ifStatements; - $accessorFailure - }""" + }.mkString("\n") + + val specificMutatorFunctions = ctx.nativeTypes.map { dataType => + val cases = expressions.zipWithIndex.map { + case (e, i) if e.dataType == dataType => + s"case $i: { c$i = value; return; }" + case _ => "" + }.mkString("\n") + if (cases.count(_ != '\n') > 0) { + s""" + @Override + public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) { + nullBits[i] = false; + switch (i) { + $cases + } + }""" + } else { + "" } - } + }.mkString("\n") - val hashValues = expressions.zipWithIndex.map { case (e,i) => - val elementName = newTermName(s"c$i") + val hashValues = expressions.zipWithIndex.map { case (e, i) => + val col = newTermName(s"c$i") val nonNull = e.dataType match { - case BooleanType => q"if ($elementName) 0 else 1" - case ByteType | ShortType | IntegerType => q"$elementName.toInt" - case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" - case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case BooleanType => s"$col ? 
0 : 1" + case ByteType | ShortType | IntegerType | DateType => s"$col" + case LongType => s"$col ^ ($col >>> 32)" + case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" - case _ => q"$elementName.hashCode" + s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" + case _ => s"$col.hashCode()" } - q"if (isNullAt($i)) 0 else $nonNull" + s"isNullAt($i) ? 0 : ($nonNull)" } - val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + val hashUpdates: String = hashValues.map( v => + s""" + result *= 37; result += $v;""" + ).mkString("\n") - val hashCodeFunction = - q""" - override def hashCode(): Int = { - var result: Int = 37 - ..$hashUpdates - result - } + val columnChecks = expressions.zipWithIndex.map { case (e, i) => + s""" + if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) { + return false; + } """ + }.mkString("\n") + + val code = s""" + import org.apache.spark.sql.Row; - val columnChecks = (0 until expressions.size).map { i => - val elementName = newTermName(s"c$i") - q"if (this.$elementName != specificType.$elementName) return false" + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); } - val equalsFunction = - q""" - override def equals(other: Any): Boolean = other match { - case specificType: SpecificRow => - ..$columnChecks - return true - case other => super.equals(other) - } - """ + class SpecificProjection extends ${typeOf[BaseProject]} { + private $exprType[] expressions = null; - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + public SpecificProjection($exprType[] expr) { + expressions = expr; + } + + @Override + public Object apply(Object r) { + return new SpecificRow(expressions, (Row) r); + } } - val copyFunction = - q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" - - val toSeqFunction = - q"override def toSeq: Seq[Any] = Seq(..$allColumns)" - - val classBody = - nullFunctions ++ ( - lengthDef +: - applyFunction +: - updateFunction +: - equalsFunction +: - hashCodeFunction +: - copyFunction +: - toSeqFunction +: - (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) - - val code = q""" - final class SpecificRow(i: $rowType) extends $mutableRowType { - ..$classBody + final class SpecificRow extends ${typeOf[BaseMutableRow]} { + + $columns + + public SpecificRow($exprType[] expressions, Row i) { + $initColumns + } + + public int size() { return ${expressions.length};} + private boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } + + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; + } + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } + } + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + @Override + public boolean equals(Object other) { + if (other instanceof Row) { + Row row = (Row) other; + if (row.length() != size()) return false; + 
$columnChecks + return true; + } + return super.equals(other); + } + } """ - log.debug( - s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") - toolBox.eval(code).asInstanceOf[Projection] + logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") + + val c = compile(code) + // fetch the only one method `generate(Expression[])` + val m = c.getDeclaredMethods()(0) + m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 528e38a50a740..7f1b12cdd5800 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,12 +27,6 @@ import org.apache.spark.util.Utils */ package object codegen { - /** - * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala - * 2.10. - */ - protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock - /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 956a2429b0b61..6398b8f9e4ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -24,10 +24,9 @@ import org.apache.spark.sql.types._ * Returns an Array containing the evaluation of all children expressions. */ case class CreateArray(children: Seq[Expression]) extends Expression { - override type EvaluatedType = Any - + override def foldable: Boolean = children.forall(_.foldable) - + lazy val childTypes = children.map(_.dataType).distinct override lazy val resolved = @@ -54,7 +53,6 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * TODO: [[CreateStruct]] does not support codegen. */ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { - override type EvaluatedType = Row override def foldable: Boolean = children.forall(_.foldable) @@ -71,7 +69,7 @@ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { override def nullable: Boolean = false - override def eval(input: Row): EvaluatedType = { + override def eval(input: Row): Any = { Row(children.map(_.eval(input)): _*) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala new file mode 100644 index 0000000000000..1a5cde26c9b13 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{BooleanType, DataType} + + +case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) + extends Expression { + + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable + + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.TypeCheckFailure( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def dataType: DataType = trueValue.dataType + + override def eval(input: Row): Any = { + if (true == predicate.eval(input)) { + trueValue.eval(input) + } else { + falseValue.eval(input) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; + } + """ + } + + override def toString: String = s"if ($predicate) $trueValue else $falseValue" +} + +trait CaseWhenLike extends Expression { + self: Product => + + // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + // element is the value for the default catch-all case (if provided). + // Hence, `branches` consists of at least two elements, and can have an odd or even length. + def branches: Seq[Expression] + + @transient lazy val whenList = + branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq + @transient lazy val thenList = + branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq + val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) + + // both then and else expressions should be considered. 
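The pairing convention described in the comment above is easiest to see on a concrete value; a small, self-contained sketch with plain strings standing in for expressions:

    // How CaseWhenLike splits `branches` into WHEN/THEN pairs plus an optional trailing ELSE.
    val branches = Seq("w1", "t1", "w2", "t2", "else")
    val whenList = branches.sliding(2, 2).collect { case Seq(w, _) => w }.toSeq  // Seq("w1", "w2")
    val thenList = branches.sliding(2, 2).collect { case Seq(_, t) => t }.toSeq  // Seq("t1", "t2")
    val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) // Some("else")

The single-element last window is skipped by the two-element pattern, so an odd-length `branches` leaves exactly the ELSE value behind.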
+ def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) + def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 + + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + checkTypesInternal() + } else { + TypeCheckResult.TypeCheckFailure( + "THEN and ELSE expressions should all be same type or coercible to a common type") + } + } + + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType + + override def nullable: Boolean = { + // If no value is nullable and no elseValue is provided, the whole statement defaults to null. + thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: Row): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + return res + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + + override def toString: String = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". 
+ * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = key +: branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if ((key +: whenList).map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + "key and WHEN expressions should all be same type or coercible to a common type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: Row): Any = { + val evaluatedKey = key.eval(input) + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + while (i < len - 1) { + if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + return res + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val keyEval = key.gen(ctx) + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (${keyEval.isNull} && ${cond.isNull} || + !${keyEval.isNull} && !${cond.isNull} + && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${keyEval.code} + $cases + $other + """ + } + + private def equalNullSafe(l: Any, r: Any) = { + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false + } else { + l == r + } + } + + override def toString: String = { + s"CASE $key" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index adb94df7d1c7b..8ab6d977dd3a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ case class UnscaledValue(child: Expression) extends UnaryExpression { - override 
type EvaluatedType = Any override def dataType: DataType = LongType override def foldable: Boolean = child.foldable @@ -36,11 +36,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { childResult.asInstanceOf[Decimal].toUnscaledLong } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + } } /** Create a Decimal from an unscaled Long value */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { - override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) override def foldable: Boolean = child.foldable @@ -55,4 +58,18 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.decimalType} ${ev.primitive} = null; + + if (!${ev.isNull}) { + ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull( + ${eval.primitive}, $precision, $scale); + ${ev.isNull} = ${ev.primitive} == null; + } + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 747a47bdde953..b6191eafba71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -40,8 +40,6 @@ import org.apache.spark.sql.types._ abstract class Generator extends Expression { self: Product => - override type EvaluatedType = TraversableOnce[Row] - // TODO ideally we should return the type of ArrayType(StructType), // however, we don't keep the output field names in the Generator. 
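The decimalFunctions.scala hunk above generates code around `toUnscaledLong` and `setOrNull`; the interpreted equivalents show why MakeDecimal's generated code re-checks `isNull` after the call: `setOrNull` returns null when the unscaled value does not fit the target precision. A small sketch using only calls that appear in the hunk; the concrete values are illustrative:

    import org.apache.spark.sql.types.Decimal

    val d = new Decimal().setOrNull(123456L, 10, 2)   // 1234.56, fits precision 10
    val unscaled = d.toUnscaledLong                   // 123456L, what UnscaledValue produces
    val overflow = new Decimal().setOrNull(unscaled, 4, 2)
    // null: 1234.56 needs 6 digits of precision, so the generated code sets isNull = true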
override def dataType: DataType = throw new UnsupportedOperationException @@ -73,12 +71,23 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator { + @transient private[this] var inputRow: InterpretedProjection = _ + @transient private[this] var convertToScala: (Row) => Row = _ + + private def initializeConverters(): Unit = { + inputRow = new InterpretedProjection(children) + convertToScala = { + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + CatalystTypeConverters.createToScalaConverter(inputSchema) + }.asInstanceOf[(Row => Row)] + } + override def eval(input: Row): TraversableOnce[Row] = { - // TODO(davies): improve this + if (inputRow == null) { + initializeConverters() + } // Convert the objects into Scala Type before calling function, we need schema to support UDT - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - val inputRow = new InterpretedProjection(children) - function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) + function(convertToScala(inputRow(input))) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" @@ -105,8 +114,8 @@ case class Explode(child: Expression) val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) case MapType(_, _, _) => - val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] - if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } + val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] + if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 5f8c7354aede1..297b35b4da94c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -78,14 +79,65 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" - type EvaluatedType = Any + override def equals(other: Any): Boolean = other match { + case o: Literal => + dataType.equals(o.dataType) && + (value == null && null == o.value || value != null && value.equals(o.value)) + case _ => false + } + override def eval(input: Row): Any = value + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + // change the isNull and primitive to consts, to inline them + if (value == null) { + ev.isNull = "true" + ev.primitive = ctx.defaultValue(dataType) + "" + } else { + dataType match { + case BooleanType => + ev.isNull = "false" + ev.primitive = value.toString + "" + case FloatType => // This must go before NumericType + val v = value.asInstanceOf[Float] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = 
s"${value}f" + "" + } + case DoubleType => // This must go before NumericType + val v = value.asInstanceOf[Double] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}" + "" + } + + case ByteType | ShortType => // This must go before NumericType + ev.isNull = "false" + ev.primitive = s"(${ctx.javaType(dataType)})$value" + "" + case dt: NumericType if !dt.isInstanceOf[DecimalType] => + ev.isNull = "false" + ev.primitive = value.toString + "" + // eval() version may be faster for non-primitive types + case other => + super.genCode(ctx, ev) + } + } + } } // TODO: Specialize case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { - type EvaluatedType = Any def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala new file mode 100644 index 0000000000000..7dacb6a9b47b6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{DataType, DoubleType} + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. 
+ * @param name The short name of the function + */ +abstract class UnaryMathExpression(f: Double => Double, name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { + self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + + // name of function in java.lang.Math + def funcName: String = name.toLowerCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } +} + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def toString: String = s"$name($left, $right)" + + override def dataType: DataType = DoubleType + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Unary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") + +case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") + +case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") + +case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") + +case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") + +case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") + +case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") + +case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") + +case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") + +case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") + +case class 
Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") + +case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") + +case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") + +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { + override def funcName: String = "rint" +} + +case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") + +case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") + +case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") + +case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") + +case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") + +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { + override def funcName: String = "toDegrees" +} + +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { + override def funcName: String = "toRadians" +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Binary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala deleted file mode 100644 index fcc06d3aa1036..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.mathfuncs - -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} -import org.apache.spark.sql.types._ - -/** - * A binary expression specifically for math functions that take two `Double`s as input and returns - * a `Double`. - * @param f The math function. - * @param name The short name of the function - */ -abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - type EvaluatedType = Any - override def symbol: String = null - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) - - override def nullable: Boolean = left.nullable || right.nullable - override def toString: String = s"$name($left, $right)" - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) - if (result.isNaN) null else result - } - } - } -} - -case class Atan2( - left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, - evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result - } - } - } -} - -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") - -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala deleted file mode 100644 index dc68469e060cb..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.mathfuncs - -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} -import org.apache.spark.sql.types._ - -/** - * A unary expression specifically for math functions. Math Functions expect a specific type of - * input format, therefore these functions extend `ExpectsInputTypes`. - * @param name The short name of the function - */ -abstract class MathematicalExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { - self: Product => - type EvaluatedType = Any - - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) - override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = true - override def toString: String = s"$name($child)" - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val result = f(evalE.asInstanceOf[Double]) - if (result.isNaN) null else result - } - } -} - -case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS") - -case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN") - -case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN") - -case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT") - -case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL") - -case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS") - -case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH") - -case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP") - -case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1") - -case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR") - -case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG") - -case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10") - -case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P") - -case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND") - -case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM") - -case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN") - -case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH") - -case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN") - -case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH") - -case class ToDegrees(child: Expression) - extends MathematicalExpression(math.toDegrees, "DEGREES") - -case class ToRadians(child: Expression) - extends MathematicalExpression(math.toRadians, 
"RADIANS") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 50be26d0b08b5..2e4b9ba678433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ object NamedExpression { @@ -111,12 +111,13 @@ case class Alias(child: Expression, name: String)( val explicitMetadata: Option[Metadata] = None) extends NamedExpression with trees.UnaryNode[Expression] { - override type EvaluatedType = Any // Alias(Generator, xx) need to be transformed into Generate(generator, ...) override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] override def eval(input: Row): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable override def metadata: Metadata = { @@ -229,7 +230,7 @@ case class AttributeReference( } // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$name#${exprId.id}$typeSuffix" @@ -240,7 +241,6 @@ case class AttributeReference( * expression id or the unresolved indicator. 
*/ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - type EvaluatedType = Any override def toString: String = name @@ -252,7 +252,7 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index f9161cf34f0c9..c2d1a4eadae29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { - type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ override def nullable: Boolean = !children.exists(!_.nullable) @@ -52,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } result } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + + children.map { e => + val eval = e.gen(ctx) + s""" + if (${ev.isNull}) { + ${eval.code} + if (!${eval.isNull}) { + ${ev.isNull} = false; + ${ev.primitive} = ${eval.primitive}; + } + } + """ + }.mkString("\n") + } } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { @@ -62,6 +81,13 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = eval.isNull + eval.code + } + override def toString: String = s"IS NULL $child" } @@ -73,6 +99,13 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def eval(input: Row): Any = { child.eval(input) != null } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = s"(!(${eval.isNull}))" + eval.code + } } /** @@ -96,4 +129,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } numNonNulls >= n } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val nonnull = ctx.freshName("nonnull") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull}) { + $nonnull += 1; + } + } + """ + 
}.mkString("\n") + s""" + int $nonnull = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $nonnull >= $n; + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1d72a9eb834b9..3cbdfdfb13847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} +import org.apache.spark.sql.types._ object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -35,8 +36,6 @@ trait Predicate extends Expression { self: Product => override def dataType: DataType = BooleanType - - type EvaluatedType = Any } trait PredicateHelper { @@ -84,6 +83,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex case b: Boolean => !b } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"!($c)") + } } /** @@ -143,6 +146,29 @@ case class And(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + // The result should be `false`, if any of them is `false` whenever the other is null or not. + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + + if (!${eval1.isNull} && !${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && !${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.isNull} = true; + } + } + """ + } } case class Or(left: Expression, right: Expression) @@ -169,59 +195,45 @@ case class Or(left: Expression, right: Expression) } } } -} -abstract class BinaryComparison extends BinaryExpression with Predicate { - self: Product => -} + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" + // The result should be `true`, if any of them is `true` whenever the other is null or not. 
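// For reference, the three-valued OR that the template below implements (mirroring
// the AND template above) follows SQL semantics:
//   true  OR anything (including null) = true
//   false OR false                     = false
//   any remaining combination involving null = null
// The "definitely true" arms are checked first so a null on one side cannot mask a
// true on the other; only when both sides are known non-null can the result be false.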
+ s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = true; - override def eval(input: Row): Any = { - val l = left.eval(input) - if (l == null) { - null - } else { - val r = right.eval(input) - if (r == null) null - else if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) - } + if (!${eval1.isNull} && ${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && ${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = false; + } else { + ${ev.isNull} = true; + } + } + """ } } -case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=>" - - override def nullable: Boolean = false +abstract class BinaryComparison extends BinaryExpression with Predicate { + self: Product => - override def eval(input: Row): Any = { - val l = left.eval(input) - val r = right.eval(input) - if (l == null && r == null) { - true - } else if (l == null || r == null) { - false + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") } else { - l == r + checkTypesInternal(dataType) } } -} -case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<" - - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + protected def checkTypesInternal(t: DataType): TypeCheckResult override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -232,258 +244,118 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso if (evalE2 == null) { null } else { - ordering.lt(evalE1, evalE2) + evalInternal(evalE1, evalE2) } } } -} -case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=" - - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") + case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" + }) + case TimestampType => + // java.sql.Timestamp does not have compare() + super.genCode(ctx, ev) + case other => defineCodeGen (ctx, ev, { + (c1, c2) => s"$c1.compare($c2) $symbol 0" + }) } } - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lteq(evalE1, evalE2) - } - } - } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryComparisons must override either eval or evalInternal") } -case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - override def 
symbol: String = ">" - - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gt(evalE1, evalE2) - } - } - } +private[sql] object BinaryComparison { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } -case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">=" +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gteq(evalE1, evalE2) - } - } + protected override def evalInternal(l: Any, r: Any) = { + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) } } -case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { - - override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil - override def nullable: Boolean = trueValue.nullable || falseValue.nullable +case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "<=>" - override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException( - this, - s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}") - } - trueValue.dataType - } + override def nullable: Boolean = false - type EvaluatedType = Any + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess override def eval(input: Row): Any = { - if (true == predicate.eval(input)) { - trueValue.eval(input) + val l = left.eval(input) + val r = right.eval(input) + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false } else { - falseValue.eval(input) + l == r } } - override def toString: String = s"if ($predicate) $trueValue else $falseValue" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) + ev.isNull = "false" + eval1.code + eval2.code + s""" + boolean 
${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || + (!${eval1.isNull} && $equalCode); + """ + } } -trait CaseWhenLike extends Expression { - self: Product => - - type EvaluatedType = Any - - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last - // element is the value for the default catch-all case (if provided). - // Hence, `branches` consists of at least two elements, and can have an odd or even length. - def branches: Seq[Expression] - - @transient lazy val whenList = - branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq - @transient lazy val thenList = - branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq - val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) +case class LessThan(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "<" - // both then and else val should be considered. - def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1 + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") - } - valueTypes.head - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override def nullable: Boolean = { - // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) - } + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) } -// scalastyle:off -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = branches +case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "<=" - override lazy val resolved: Boolean = - childrenResolved && - whenList.forall(_.dataType == BooleanType) && - valueTypesEqual + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. 
- while (i < len - 1) { - if (branchesArr(i).eval(input) == true) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) - } - return res - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override def toString: String = { - "CASE" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) } -// scalastyle:off -/** - * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { +case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = ">" - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def children: Seq[Expression] = key +: branches + private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) +} - /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { - val evaluatedKey = key.eval(input) - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. 
- while (i < len - 1) { - if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) - } - return res - } +case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = ">=" - private def equalNullSafe(l: Any, r: Any) = { - if (l == null && r == null) { - true - } else if (l == null || r == null) { - false - } else { - l == r - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def toString: String = { - s"CASE $key" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 66d7c8b07cce8..6e4e9cb1be090 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** * A Random distribution generating expression. - * TODO: This can be made generic to generate any type of random distribution, or any type of + * TODO: This can be made generic to generate any type of random distribution, or any type of * StructType. * * Since this expression is stateful, it cannot be a case object. @@ -38,7 +39,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { */ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) - override type EvaluatedType = Double + override def deterministic: Boolean = false override def nullable: Boolean = false @@ -46,11 +47,29 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ -case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) { +case class Rand(seed: Long) extends RDG(seed) { override def eval(input: Row): Double = rng.nextDouble() } +object Rand { + def apply(): Rand = apply(Utils.random.nextLong()) + + def apply(seed: Expression): Rand = apply(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) +} + /** Generate a random column with i.i.d. gaussian random distribution. 
*/ -case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) { +case class Randn(seed: Long) extends RDG(seed) { override def eval(input: Row): Double = rng.nextGaussian() } + +object Randn { + def apply(): Randn = apply(Utils.random.nextLong()) + + def apply(seed: Expression): Randn = apply(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 4c44182278207..2bcb960e9177e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -51,7 +52,6 @@ private[sql] class OpenHashSetUDT( * Creates a new set of the specified type */ case class NewSet(elementType: DataType) extends LeafExpression { - type EvaluatedType = Any override def nullable: Boolean = false @@ -61,6 +61,17 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + elementType match { + case IntegerType | LongType => + ev.isNull = "false" + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}(); + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"new Set($dataType)" } @@ -69,7 +80,6 @@ case class NewSet(elementType: DataType) extends LeafExpression { * For performance, this expression mutates its input during evaluation. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { - type EvaluatedType = Any override def children: Seq[Expression] = item :: set :: Nil @@ -93,6 +103,25 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = "false" + ev.primitive = setEval.primitive + itemEval.code + setEval.code + s""" + if (!${itemEval.isNull} && !${setEval.isNull}) { + (($htype)${setEval.primitive}).add(${itemEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"$set += $item" } @@ -101,7 +130,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { * For performance, this expression mutates its left input set during evaluation. 
*/ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable @@ -119,21 +147,37 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightValue = iterator.next() leftEval.add(rightValue) } - leftEval - } else { - null } + leftEval } else { null } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = leftEval.isNull + ev.primitive = leftEval.primitive + leftEval.code + rightEval.code + s""" + if (!${leftEval.isNull} && !${rightEval.isNull}) { + ${leftEval.primitive}.union((${htype})${rightEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } } /** * Returns the number of elements in the input set. */ case class CountSet(child: Expression) extends UnaryExpression { - type EvaluatedType = Any override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 7683e0990ce80..856f56488c7a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ trait StringRegexExpression extends ExpectsInputTypes { self: BinaryExpression => - type EvaluatedType = Any - def escape(v: String): String def matches(regex: Pattern, str: String): Boolean @@ -40,14 +39,14 @@ trait StringRegexExpression extends ExpectsInputTypes { case _ => null } - protected def compile(str: String): Pattern = if(str == null) { + protected def compile(str: String): Pattern = if (str == null) { null } else { // Let it raise exception if couldn't compile the regex string Pattern.compile(escape(str)) } - protected def pattern(str: String) = if(cache == null) compile(str) else cache + protected def pattern(str: String) = if (cache == null) compile(str) else cache override def eval(input: Row): Any = { val l = left.eval(input) @@ -114,8 +113,6 @@ case class RLike(left: Expression, right: Expression) trait CaseConversionExpression extends ExpectsInputTypes { self: UnaryExpression => - type EvaluatedType = Any - def convert(v: UTF8String): UTF8String override def foldable: Boolean = child.foldable @@ -137,20 +134,28 @@ trait CaseConversionExpression extends ExpectsInputTypes { * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") + } } /** * A function that converts the characters of a string to lowercase. 
*/ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") + } } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -159,8 +164,6 @@ trait StringComparison extends ExpectsInputTypes { def compare(l: UTF8String, r: UTF8String): Boolean - override type EvaluatedType = Any - override def nullable: Boolean = left.nullable || right.nullable override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) @@ -187,6 +190,9 @@ trait StringComparison extends ExpectsInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") + } } /** @@ -195,6 +201,9 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") + } } /** @@ -203,6 +212,9 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") + } } /** @@ -211,12 +223,11 @@ case class EndsWith(left: Expression, right: Expression) */ case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression with ExpectsInputTypes { - - type EvaluatedType = Any override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") @@ -231,7 +242,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) @inline def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it + // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. 
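As a concrete illustration of the one-based SUBSTR indexing described in the comment above (and of the two-argument form added just below, which defaults len to Integer.MAX_VALUE), here is a small standalone sketch of the documented start-position rules. It approximates the described semantics only; substr and SubstrSketch are hypothetical names, not the slicePos implementation.

    object SubstrSketch extends App {
      // Positions are one-based, a negative position counts from the end, and 0
      // behaves like 1. Out-of-range negative starts are clamped to 0 here for
      // simplicity; the real expression may handle that case differently.
      def substr(s: String, pos: Int, len: Int = Int.MaxValue): String = {
        val start =
          if (pos > 0) pos - 1                          // i refers to element i-1
          else if (pos < 0) math.max(s.length + pos, 0) // -i counts back from the end
          else 0                                        // 0 refers to the first element
        val end = math.min(s.length.toLong, start.toLong + len).toInt // avoid start+len overflow
        if (start >= s.length) "" else s.substring(start, end)
      }

      println(substr("Spark SQL", 7))     // "SQL"
      println(substr("Spark SQL", -3))    // "SQL"
      println(substr("Spark SQL", 0, 5))  // "Spark"
    }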
@@ -277,3 +288,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) case _ => s"SUBSTR($str, $pos, $len)" } } + +object Substring { + def apply(str: Expression, pos: Expression): Substring = { + apply(str, pos, Literal(Integer.MAX_VALUE)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 099d67ca7fee3..82c4d462cc322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -66,9 +66,7 @@ case class WindowSpecDefinition( } } - type EvaluatedType = Any - - override def children: Seq[Expression] = partitionSpec ++ orderSpec + override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] @@ -76,7 +74,7 @@ case class WindowSpecDefinition( override def toString: String = simpleString - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -299,7 +297,7 @@ case class UnresolvedWindowFunction( override def get(index: Int): Any = throw new UnresolvedException(this, "get") // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -311,25 +309,25 @@ case class UnresolvedWindowFunction( case class UnresolvedWindowExpression( child: UnresolvedWindowFunction, windowSpec: WindowSpecReference) extends UnaryExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } case class WindowExpression( windowFunction: WindowFunction, windowSpec: WindowSpecDefinition) extends Expression { - override type EvaluatedType = Any override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil - override def eval(input: Row): EvaluatedType = + override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") override def dataType: DataType = windowFunction.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c2818d957cc79..c16f08d389955 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,6 +36,8 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Distinct", FixedPoint(100), + ReplaceDistinctWithAggregate) :: Batch("Operator Reordering", FixedPoint(100), UnionPushdown, CombineFilters, @@ -179,8 +181,17 @@ object ColumnPruning extends Rule[LogicalPlan] { * expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { + + /** Returns true if any expression in projectList is non-deterministic. */ + private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { + projectList.exists(expr => expr.find(!_.deterministic).isDefined) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Project(projectList1, Project(projectList2, child)) => + // We only collapse these two Projects if the child Project's expressions are all + // deterministic. + case Project(projectList1, Project(projectList2, child)) + if !hasNondeterministic(projectList2) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { @@ -255,7 +266,7 @@ object NullPropagation extends Rule[LogicalPlan] { if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { - newChildren(0) + newChildren.head } else { Coalesce(newChildren) } @@ -264,22 +275,23 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) + // MaxOf and MinOf can't do null propagation + case e: MaxOf => e + case e: MinOf => e + // Put exceptional cases above if any - case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } + case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) + + case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e: StringRegexExpression => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } + case e: StringComparison => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => 
Literal.create(null, e.dataType) @@ -683,3 +695,15 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } } + +/** + * Replaces logical [[Distinct]] operator with an [[Aggregate]] operator. + * {{{ + * SELECT DISTINCT f1, f2 FROM t ==> SELECT f1, f2 FROM t GROUP BY f1, f2 + * }}} + */ +object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Distinct(child) => Aggregate(child.output, child.output, child) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7967189cacb24..eff5c61644944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -84,7 +84,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map { case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) @@ -117,7 +117,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 01f4b6e9bb77d..e77e5c27b687a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -93,7 +93,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved: Boolean = childrenResolved && - left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -339,6 +339,9 @@ case class Sample( override def output: Seq[Attribute] = child.output } +/** + * Returns a new logical plan that dedups input rows. 
+ */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index fb4217a44807b..80ba57a082a60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -169,7 +169,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def keyExpressions: Seq[Expression] = expressions - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -213,6 +213,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def keyExpressions: Seq[Expression] = ordering.map(_.child) - override def eval(input: Row): EvaluatedType = + override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 28e15566f0961..36d005d0e1684 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -254,7 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -311,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index 3f92be4a55d7d..ad649acf536f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -24,7 +24,7 @@ import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast /** - * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + * Helper function to convert between Int value of days since 1970-01-01 and java.sql.Date */ object DateUtils { private val MILLIS_PER_DAY = 86400000 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala new file mode 100644 index 0000000000000..191d5e6399fc9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + + def remove(key: String): Option[T] = base.remove(normalizer(key)) + + def iterator: Iterator[(String, T)] = base.toIterator +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala new file mode 100644 index 0000000000000..0bb12d2039ffc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
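A small usage sketch for the StringKeyHashMap added above, assuming the API exactly as written in this patch; with caseSensitive = false every key is normalized to lower case before it touches the underlying HashMap:

import org.apache.spark.sql.catalyst.util.StringKeyHashMap

object StringKeyHashMapSketch extends App {
  // Case-insensitive: "SparkSQL", "sparksql" and "SPARKSQL" all normalize to the same key.
  val insensitive = StringKeyHashMap[Int](caseSensitive = false)
  insensitive.put("SparkSQL", 1)
  assert(insensitive.get("sparksql") == Some(1))
  assert(insensitive.get("SPARKSQL") == Some(1))

  // Case-sensitive: identity normalization, so differently-cased keys stay distinct.
  val sensitive = StringKeyHashMap[Int](caseSensitive = true)
  sensitive.put("SparkSQL", 1)
  assert(sensitive.get("sparksql") == None)
  assert(sensitive("SparkSQL") == 1)
}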
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types._ + +/** + * Helper function to check for valid data types + */ +object TypeUtils { + def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[NumericType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t") + } + } + + def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[IntegralType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") + } + } + + def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[AtomicType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t") + } + } + + def getNumeric(t: DataType): Numeric[Any] = + t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + + def getOrdering(t: DataType): Ordering[Any] = + t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 9d613a940ee86..07054166a5e88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -83,7 +83,7 @@ package object util { } def resourceToString( - resource:String, + resource: String, encoding: String = "UTF-8", classLoader: ClassLoader = Utils.getSparkClassLoader): String = { new String(resourceToBytes(resource, classLoader), encoding) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a0b261649f66f..74677ddfcad65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -107,7 +107,7 @@ protected[sql] abstract class AtomicType extends DataType { abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. 
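A usage sketch for the TypeUtils helpers above, assuming they and the TypeCheckResult values they return are on the classpath; note that NullType is deliberately accepted by every check so that untyped NULL literals still type-check:

import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

object TypeUtilsSketch extends App {
  println(TypeUtils.checkForNumericExpr(DoubleType, "function sum"))   // TypeCheckSuccess
  println(TypeUtils.checkForNumericExpr(NullType, "function sum"))     // TypeCheckSuccess
  println(TypeUtils.checkForNumericExpr(StringType, "function sum"))   // TypeCheckFailure(...)

  // Bitwise operators want integral types; ordering wants non-complex (atomic) types.
  println(TypeUtils.checkForBitwiseExpr(IntegerType, "operator &"))    // TypeCheckSuccess
  println(TypeUtils.checkForOrderingExpr(ArrayType(IntegerType), "function max")) // TypeCheckFailure(...)
}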
private[sql] val numeric: Numeric[InternalType] @@ -165,6 +165,9 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) + /** + * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` + */ @deprecated("Use DataType.fromJson instead", "1.2.0") def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) @@ -271,7 +274,7 @@ object DataType { protected lazy val structField: Parser[StructField] = ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 994c5202c15dc..eb3c58c37f308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -313,7 +313,7 @@ object Decimal { // See scala.math's Numeric.scala for examples for Scala's built-in types. /** Common methods for Decimal evidence parameters */ - trait DecimalIsConflicted extends Numeric[Decimal] { + private[sql] trait DecimalIsConflicted extends Numeric[Decimal] { override def plus(x: Decimal, y: Decimal): Decimal = x + y override def times(x: Decimal, y: Decimal): Decimal = x * y override def minus(x: Decimal, y: Decimal): Decimal = x - y @@ -327,12 +327,12 @@ object Decimal { } /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ - object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + private[sql] object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { override def div(x: Decimal, y: Decimal): Decimal = x / y } /** A [[scala.math.Integral]] evidence parameter for Decimals. 
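The new @deprecated tag steers callers to DataType.fromJson. A round-trip sketch against the public types API, assuming the usual DataType.json serializer as the counterpart:

import org.apache.spark.sql.types._

object FromJsonSketch extends App {
  val schema = StructType(Seq(
    StructField("id", IntegerType, nullable = false),
    StructField("name", StringType, nullable = true)))

  // Serialize the schema to its JSON representation and parse it back.
  val roundTripped = DataType.fromJson(schema.json)
  assert(roundTripped == schema)
  println(schema.json)
}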
*/ - object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + private[sql] object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { override def quot(x: Decimal, y: Decimal): Decimal = x / y override def rem(x: Decimal, y: Decimal): Decimal = x % y } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0f8cecd28f7df..407dc27326c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -82,12 +82,12 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT object DecimalType { val Unlimited: DecimalType = DecimalType(None) - object Fixed { + private[sql] object Fixed { def unapply(t: DecimalType): Option[(Int, Int)] = t.precisionInfo.map(p => (p.precision, p.scale)) } - object Expression { + private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java index a64d2bb7cde37..df64a878b6b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -24,11 +24,11 @@ /** * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. - * + *
    * WARNING: This annotation will only work if both Java and Scala reflection return the same class * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class * is enclosed in an object (a singleton). - * + *
    * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 7e00a27dfe724..193c08a4d0df7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -230,10 +230,10 @@ object StructType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightFields - .find(_.name == leftName) + rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => leftField.copy( dataType = merge(leftType, rightType), @@ -243,8 +243,9 @@ object StructType { .foreach(newFields += _) } + val leftMapped = fieldsMap(leftFields) rightFields - .filterNot(f => leftFields.map(_.name).contains(f.name)) + .filterNot(f => leftMapped.get(f.name).nonEmpty) .foreach(newFields += _) StructType(newFields) @@ -264,4 +265,9 @@ object StructType { case _ => throw new SparkException(s"Failed to merge incompatible data types $left and $right") } + + private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { + import scala.collection.breakOut + fields.map(s => (s.name, s))(breakOut) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index bc9c37bf2d5d2..f5d8fcced362b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -203,7 +203,7 @@ object UTF8String { def apply(s: String): UTF8String = { if (s != null) { new UTF8String().set(s) - } else{ + } else { null } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala new file mode 100644 index 0000000000000..df0f04563edcf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
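The StructType.merge change above swaps repeated linear find calls over rightFields for a name -> field map built once per merge (fieldsMap, using breakOut to avoid an intermediate collection). A plain-Scala sketch of the access-pattern change, with made-up field data:

object FieldsMapSketch extends App {
  case class Field(name: String, dataType: String)

  val leftFields  = Array(Field("a", "int"), Field("b", "string"))
  val rightFields = Array(Field("b", "string"), Field("c", "double"))

  // Before: every left field scanned rightFields linearly, O(|left| * |right|) overall.
  val foundLinearly = leftFields.flatMap(l => rightFields.find(_.name == l.name))

  // After: build the lookup once, then each probe is a hash lookup.
  val rightMapped: Map[String, Field] = rightFields.map(f => f.name -> f).toMap
  val foundViaMap = leftFields.flatMap(l => rightMapped.get(l.name))

  assert(foundLinearly.toSeq == foundViaMap.toSeq)
  println(foundViaMap.toSeq)
}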
+ */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class CatalystTypeConvertersSuite extends SparkFunSuite { + + private val simpleTypes: Seq[DataType] = Seq( + StringType, + DateType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + test("null handling in rows") { + val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) + val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) + val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) + + val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) + assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) + } + + test("null handling for individual values") { + for (dataType <- simpleTypes) { + assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) + } + } + + test("option handling in convertToCatalyst") { + // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with + // createToCatalystConverter but it may not actually matter as this is only called internally + // in a handful of places where we don't expect to receive Options. + assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) + } + + test("option handling in createToCatalystConverter") { + assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index ea82cd2622de9..c046dbf4dc2c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.physical._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -class DistributionSuite extends FunSuite { +class DistributionSuite extends SparkFunSuite { protected def checkSatisfied( inputPartitioning: Partitioning, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index bbc0b661a0c0c..9a24b23024e18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ @@ -75,7 +74,7 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } -class ScalaReflectionSuite extends FunSuite { +class ScalaReflectionSuite extends SparkFunSuite { import ScalaReflection._ test("primitive data") { @@ -253,7 +252,7 @@ class ScalaReflectionSuite extends FunSuite { } assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + 
assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) } test("convert PrimitiveData to catalyst") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 890ea2a84b82e..b93a3abc6ebd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.Command -import org.scalatest.FunSuite private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -28,7 +28,7 @@ private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Comman } private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") + protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") override protected lazy val start: Parser[LogicalPlan] = set @@ -39,7 +39,7 @@ private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { } private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("EXECUTE") + protected val EXECUTE = Keyword("EXECUTE") override protected lazy val start: Parser[LogicalPlan] = set @@ -49,7 +49,7 @@ private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { } } -class SqlParserSuite extends FunSuite { +class SqlParserSuite extends SparkFunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 939cefb71b817..e09cd790a7187 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends FunSuite with BeforeAndAfter { +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -155,7 +156,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { caseSensitive: Boolean = true): Unit = { test(name) { val error = intercept[AnalysisException] { - if(caseSensitive) { + if (caseSensitive) { caseSensitiveAnalyze(plan) } else { caseInsensitiveAnalyze(plan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 565b1cfe019c7..7bac97b7894f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.SimpleCatalystConf -class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { +class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) val catalog = new SimpleCatalog(conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) @@ -91,8 +92,10 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(LessThan(i, d1), DecimalType.Unlimited) - checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index fcd745f43cfbf..9977f7af00f6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,18 +20,19 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = HiveTypeCoercion.findTightestCommonType(t1, t2) + var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. 
- found = HiveTypeCoercion.findTightestCommonType(t2, t1) + found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } @@ -104,31 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } - test("boolean casts") { - val booleanCasts = new HiveTypeCoercion { }.BooleanCasts - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - // Remove superflous boolean -> boolean casts. - ruleTest(Cast(Literal(true), BooleanType), Literal(true)) - // Stringify boolean when casting to string. - ruleTest( - Cast(Literal(false), StringType), - If(Literal(false), Literal("true"), Literal("false"))) + private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + rule(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) } test("coalesce casts") { val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - fac(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - ruleTest( + ruleTest(fac, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -137,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest( + ruleTest(fac, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -147,4 +133,58 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) :: Nil)) } + + test("type coercion for CaseKeyWhen") { + val cwc = new HiveTypeCoercion {}.CaseWhenCoercion + ruleTest(cwc, + CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) + ) + ruleTest(cwc, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) + } + + test("type coercion simplification for equal to") { + val be = new HiveTypeCoercion {}.BooleanEquality + + ruleTest(be, + EqualTo(Literal(true), Literal(1)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(true), Literal(0)), + Not(Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(1)), + And(IsNotNull(Literal(true)), Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(0)), + And(IsNotNull(Literal(true)), Not(Literal(true))) + ) + + ruleTest(be, + EqualTo(Literal(true), Literal(1L)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(BigDecimal(0)), Literal(true)), + Not(Literal(true)) + ) + ruleTest(be, + EqualTo(Literal(Decimal(1)), Literal(true)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)), + 
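The BooleanEquality cases being exercised here simplify comparisons between a boolean and a numeric 1/0 literal. A standalone sketch of the truth table the simplification preserves (ignoring the null-safe variants):

object BooleanEqualitySketch extends App {
  // b = 1 simplifies to b, and b = 0 simplifies to NOT b, because casting a boolean
  // to a numeric type yields 1 for true and 0 for false.
  for (b <- Seq(true, false)) {
    val asInt = if (b) 1 else 0
    assert((asInt == 1) == b)   // b = 1  <=>  b
    assert((asInt == 0) == !b)  // b = 0  <=>  NOT b
  }
  println("boolean/numeric equality simplification holds")
}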
Literal(true) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala new file mode 100644 index 0000000000000..e1afa81a7a82f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{DoubleType, IntegerType} + + +class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("arithmetic") { + val row = create_row(1, 2, 3, null) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + + checkEvaluation(UnaryMinus(c1), -1, row) + checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) + + checkEvaluation(Add(c1, c4), null, row) + checkEvaluation(Add(c1, c2), 3, row) + checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(-c1, -1, row) + checkEvaluation(c1 + c2, 3, row) + checkEvaluation(c1 - c2, -1, row) + checkEvaluation(c1 * c2, 2, row) + checkEvaluation(c1 / c2, 0, row) + checkEvaluation(c1 % c2, 1, row) + } + + test("fractional arithmetic") { + val row = create_row(1.1, 2.0, 3.1, null) + val c1 = 'a.double.at(0) + val c2 = 'a.double.at(1) + val c3 = 'a.double.at(2) + val c4 = 'a.double.at(3) + + checkEvaluation(UnaryMinus(c1), -1.1, row) + checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) + checkEvaluation(Add(c1, c4), null, row) + checkEvaluation(Add(c1, c2), 3.1, row) + checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) + + checkEvaluation(-c1, -1.1, row) + checkEvaluation(c1 + c2, 3.1, row) + checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) + checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) + checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) + checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) + } + + test("Divide") { + checkEvaluation(Divide(Literal(2), Literal(1)), 2) + checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) + checkEvaluation(Divide(Literal(1), Literal(2)), 0) + 
checkEvaluation(Divide(Literal(1), Literal(0)), null) + checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) + checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) + checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) + } + + test("Remainder") { + checkEvaluation(Remainder(Literal(2), Literal(1)), 0) + checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) + checkEvaluation(Remainder(Literal(1), Literal(2)), 1) + checkEvaluation(Remainder(Literal(1), Literal(0)), null) + checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) + checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) + checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) + } + + test("MaxOf") { + checkEvaluation(MaxOf(1, 2), 2) + checkEvaluation(MaxOf(2, 1), 2) + checkEvaluation(MaxOf(1L, 2L), 2L) + checkEvaluation(MaxOf(2L, 1L), 2L) + + checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + } + + test("MinOf") { + checkEvaluation(MinOf(1, 2), 1) + checkEvaluation(MinOf(2, 1), 1) + checkEvaluation(MinOf(1L, 2L), 1L) + checkEvaluation(MinOf(2L, 1L), 1L) + + checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) + checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + } + + test("SQRT") { + val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) + val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) + val d = 'a.double.at(0) + + for ((row, expected) <- rowSequence zip expectedResults) { + checkEvaluation(Sqrt(d), expected, row) + } + + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(Sqrt(-1), null, EmptyRow) + checkEvaluation(Sqrt(-1.5), null, EmptyRow) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index f2f3a84d19380..97cfb5f06dd73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.IntegerType -class AttributeSetSuite extends FunSuite { +class AttributeSetSuite extends SparkFunSuite { val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) val aLower = AttributeReference("a", 
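The Divide, Remainder, MaxOf/MinOf and Sqrt tests above pin down SQL null conventions that differ from plain JVM arithmetic: division or remainder by zero yields null instead of throwing or producing infinity/NaN, MaxOf/MinOf fall back to the non-null operand, and the square root of a negative number is null. A sketch of those conventions, with Option standing in for a nullable value:

object NullSemanticsSketch extends App {
  // x / 0 and x % 0 are null in SQL, not ArithmeticException / Infinity / NaN.
  def sqlDivide(x: Double, y: Double): Option[Double] = if (y == 0.0) None else Some(x / y)
  def sqlRemainder(x: Int, y: Int): Option[Int]       = if (y == 0) None else Some(x % y)

  // MaxOf ignores a null operand instead of propagating it.
  def sqlMaxOf(x: Option[Int], y: Option[Int]): Option[Int] = (x, y) match {
    case (Some(a), Some(b)) => Some(math.max(a, b))
    case (a, b)             => a.orElse(b)
  }

  // sqrt of a negative input is null.
  def sqlSqrt(x: Double): Option[Double] = if (x < 0) None else Some(math.sqrt(x))

  assert(sqlDivide(1.0, 0.0) == None)
  assert(sqlRemainder(1, 0) == None)
  assert(sqlDivide(1.0, 2.0) == Some(0.5))
  assert(sqlMaxOf(None, Some(2)) == Some(2))
  assert(sqlSqrt(-1.0) == None)
  println("SQL null conventions hold")
}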
IntegerType)(exprId = ExprId(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala new file mode 100644 index 0000000000000..c9bbc7a8b8c14 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ + + +class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Bitwise operations") { + val row = create_row(1, 2, 3, null) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + + checkEvaluation(BitwiseAnd(c1, c4), null, row) + checkEvaluation(BitwiseAnd(c1, c2), 0, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseOr(c1, c4), null, row) + checkEvaluation(BitwiseOr(c1, c2), 3, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseXor(c1, c4), null, row) + checkEvaluation(BitwiseXor(c1, c2), 3, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseNot(c4), null, row) + checkEvaluation(BitwiseNot(c1), -2, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) + + checkEvaluation(c1 & c2, 0, row) + checkEvaluation(c1 | c2, 3, row) + checkEvaluation(c1 ^ c2, 3, row) + checkEvaluation(~c1, -2, row) + } + + test("unary BitwiseNOT") { + checkEvaluation(BitwiseNot(1), -2) + assert(BitwiseNot(1).dataType === IntegerType) + assert(BitwiseNot(1).eval(EmptyRow).isInstanceOf[Int]) + + checkEvaluation(BitwiseNot(1.toLong), -2.toLong) + assert(BitwiseNot(1.toLong).dataType === LongType) + assert(BitwiseNot(1.toLong).eval(EmptyRow).isInstanceOf[Long]) + + checkEvaluation(BitwiseNot(1.toShort), -2.toShort) + assert(BitwiseNot(1.toShort).dataType === ShortType) + assert(BitwiseNot(1.toShort).eval(EmptyRow).isInstanceOf[Short]) + + checkEvaluation(BitwiseNot(1.toByte), -2.toByte) + assert(BitwiseNot(1.toByte).dataType === ByteType) + assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) 
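The unary BitwiseNOT assertions above check both the value and that the result keeps the operand's width (Byte stays Byte, Short stays Short, and so on). Plain Scala promotes sub-int operands to Int under ~, so this sketch narrows back explicitly to mirror that behaviour:

object BitwiseNotWidthSketch extends App {
  val b: Byte  = 1
  val s: Short = 1

  // Scala's ~ promotes Byte/Short to Int; the SQL expression instead evaluates in the
  // operand's own type, which the explicit narrowing here simulates.
  assert((~b).toByte == -2.toByte)
  assert((~s).toShort == -2.toShort)
  assert(~1 == -2)
  assert(~1L == -2L)
  println("~x == -(x + 1) in every integral width")
}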
+ } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala new file mode 100644 index 0000000000000..5bc7c30eee1b6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -0,0 +1,532 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Timestamp, Date} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +/** + * Test suite for data type casting expression [[Cast]]. + */ +class CastSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def cast(v: Any, targetType: DataType): Cast = { + v match { + case lit: Expression => Cast(lit, targetType) + case _ => Cast(Literal(v), targetType) + } + } + + // expected cannot be null + private def checkCast(v: Any, expected: Any): Unit = { + checkEvaluation(cast(v, Literal(expected).dataType), expected) + } + + test("cast from int") { + checkCast(0, false) + checkCast(1, true) + checkCast(5, true) + checkCast(1, 1.toByte) + checkCast(1, 1.toShort) + checkCast(1, 1) + checkCast(1, 1.toLong) + checkCast(1, 1.0f) + checkCast(1, 1.0) + checkCast(123, "123") + + checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 1)), null) + checkEvaluation(cast(123, DecimalType(2, 0)), null) + } + + test("cast from long") { + checkCast(0L, false) + checkCast(1L, true) + checkCast(5L, true) + checkCast(1L, 1.toByte) + checkCast(1L, 1.toShort) + checkCast(1L, 1) + checkCast(1L, 1.toLong) + checkCast(1L, 1.0f) + checkCast(1L, 1.0) + checkCast(123L, "123") + + checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + + // TODO: Fix the following bug and re-enable it. 
+ // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + } + + test("cast from boolean") { + checkEvaluation(cast(true, IntegerType), 1) + checkEvaluation(cast(false, IntegerType), 0) + checkEvaluation(cast(true, StringType), "true") + checkEvaluation(cast(false, StringType), "false") + checkEvaluation(cast(cast(1, BooleanType), IntegerType), 1) + checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0) + } + + test("cast from int 2") { + checkEvaluation(cast(1, LongType), 1.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + + checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 1)), null) + checkEvaluation(cast(123, DecimalType(2, 0)), null) + } + + test("cast from float") { + + } + + test("cast from double") { + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + } + + test("cast from string") { + assert(cast("abcdef", StringType).nullable === false) + assert(cast("abcdef", BinaryType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", TimestampType).nullable === true) + assert(cast("abcdef", LongType).nullable === true) + assert(cast("abcdef", IntegerType).nullable === true) + assert(cast("abcdef", ShortType).nullable === true) + assert(cast("abcdef", ByteType).nullable === true) + assert(cast("abcdef", DecimalType.Unlimited).nullable === true) + assert(cast("abcdef", DecimalType(4, 2)).nullable === true) + assert(cast("abcdef", DoubleType).nullable === true) + assert(cast("abcdef", FloatType).nullable === true) + } + + test("data type casting") { + val sd = "1970-01-01" + val d = Date.valueOf(sd) + val zts = sd + " 00:00:00" + val sts = sd + " 00:00:02" + val nts = sts + ".1" + val ts = Timestamp.valueOf(nts) + + checkEvaluation(cast("abdef", StringType), "abdef") + checkEvaluation(cast("abdef", DecimalType.Unlimited), null) + checkEvaluation(cast("abdef", TimestampType), null) + checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65)) + + checkEvaluation(cast(cast(sd, DateType), StringType), sd) + checkEvaluation(cast(cast(d, StringType), DateType), 0) + checkEvaluation(cast(cast(nts, TimestampType), StringType), nts) + checkEvaluation(cast(cast(ts, StringType), TimestampType), ts) + + // all convert to string type to check + checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd) + checkEvaluation(cast(cast(cast(ts, DateType), TimestampType), StringType), zts) + + checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef") + + checkEvaluation(cast(cast(cast(cast( + cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType), + 5.toLong) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), + DecimalType.Unlimited), LongType), StringType), ShortType), + 0.toShort) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), + DecimalType.Unlimited), LongType), StringType), ShortType), + null) + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited), + ByteType), TimestampType), LongType), StringType), ShortType), + 0.toShort) + + checkEvaluation(cast("23", DoubleType), 23d) + checkEvaluation(cast("23", IntegerType), 23) + checkEvaluation(cast("23", FloatType), 23f) 
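In the "cast from int 2" block above, 1000 round-trips through TimestampType to 1 while -1200 becomes -2: the millisecond-to-second step truncates toward negative infinity rather than toward zero. A sketch of just that arithmetic, inferred from the asserted values and written without the Catalyst types:

object MillisToSecondsSketch extends App {
  // Floor division, matching 1000 -> 1 and -1200 -> -2 in the test above.
  def wholeSeconds(millis: Long): Long = math.floor(millis / 1000.0).toLong

  assert(wholeSeconds(1000L) == 1L)
  assert(wholeSeconds(-1200L) == -2L)
  assert(wholeSeconds(-1200L) != -1200L / 1000L)  // plain integer division would give -1
  println("millisecond counts floor to whole seconds")
}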
+ checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23)) + checkEvaluation(cast("23", ByteType), 23.toByte) + checkEvaluation(cast("23", ShortType), 23.toShort) + checkEvaluation(cast("2012-12-11", DoubleType), null) + checkEvaluation(cast(123, IntegerType), 123) + + + checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null) + } + + test("cast and add") { + checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d) + checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24) + checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f) + checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24)) + checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte) + checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) + } + + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(cast(123, DecimalType.Unlimited).nullable === false) + assert(cast(10.03f, DecimalType.Unlimited).nullable === true) + assert(cast(10.03, DecimalType.Unlimited).nullable === true) + assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false) + + assert(cast(123, DecimalType(2, 1)).nullable === true) + assert(cast(10.03f, DecimalType(2, 1)).nullable === true) + assert(cast(10.03, DecimalType(2, 1)).nullable === true) + assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) + + + checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.03, DecimalType(1, 0)), null) + checkEvaluation(cast(10.03, DecimalType(2, 1)), null) + checkEvaluation(cast(10.03, DecimalType(3, 2)), null) + checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null) + + checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.05, DecimalType(1, 0)), null) + checkEvaluation(cast(10.05, DecimalType(2, 1)), null) + checkEvaluation(cast(10.05, DecimalType(3, 2)), null) + checkEvaluation(cast(Decimal(10.05), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(cast(Decimal(10.05), DecimalType(3, 2)), null) + + checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(9.95, DecimalType(2, 1)), null) + checkEvaluation(cast(9.95, DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(Decimal(9.95), DecimalType(1, 0)), null) + + checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(cast(-9.95, DecimalType(2, 1)), null) + 
checkEvaluation(cast(-9.95, DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) + + checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null) + checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null) + + checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) + checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null) + } + + test("cast from date") { + val d = Date.valueOf("1970-01-01") + checkEvaluation(cast(d, ShortType), null) + checkEvaluation(cast(d, IntegerType), null) + checkEvaluation(cast(d, LongType), null) + checkEvaluation(cast(d, FloatType), null) + checkEvaluation(cast(d, DoubleType), null) + checkEvaluation(cast(d, DecimalType.Unlimited), null) + checkEvaluation(cast(d, DecimalType(10, 2)), null) + checkEvaluation(cast(d, StringType), "1970-01-01") + checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + } + + test("cast from timestamp") { + val millis = 15 * 1000 + 2 + val seconds = millis * 1000 + 2 + val ts = new Timestamp(millis) + val tss = new Timestamp(seconds) + checkEvaluation(cast(ts, ShortType), 15.toShort) + checkEvaluation(cast(ts, IntegerType), 15) + checkEvaluation(cast(ts, LongType), 15.toLong) + checkEvaluation(cast(ts, FloatType), 15.002f) + checkEvaluation(cast(ts, DoubleType), 15.002) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), ts) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), ts) + checkEvaluation(cast(cast(tss, LongType), TimestampType), ts) + checkEvaluation( + cast(cast(millis.toFloat / 1000, TimestampType), FloatType), + millis.toFloat / 1000) + checkEvaluation( + cast(cast(millis.toDouble / 1000, TimestampType), DoubleType), + millis.toDouble / 1000) + checkEvaluation( + cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited), + Decimal(1)) + + // A test for higher precision than millis + checkEvaluation(cast(cast(0.00000001, TimestampType), DoubleType), 0.00000001) + + checkEvaluation(cast(Double.NaN, TimestampType), null) + checkEvaluation(cast(1.0 / 0.0, TimestampType), null) + checkEvaluation(cast(Float.NaN, TimestampType), null) + checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) + } + + test("cast from array") { + val array = Literal.create(Seq("123", "abc", "", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), + ArrayType(StringType, containsNull = false)) + + { + val ret = cast(array, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(123, null, null, null)) + } + { + val ret = cast(array, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(array, ArrayType(BooleanType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false, null)) + } + { + val ret = cast(array, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved === false) + } + + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(123, null, null)) + } + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false)) + 
assert(ret.resolved === false) + } + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false)) + } + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false)) + } + + { + val ret = cast(array, IntegerType) + assert(ret.resolved === false) + } + } + + test("cast from map") { + val map = Literal.create( + Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val map_notNull = Literal.create( + Map("a" -> "123", "b" -> "abc", "c" -> ""), + MapType(StringType, StringType, valueContainsNull = false)) + + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) + } + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + } + { + val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } + + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null)) + } + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } + + { + val ret = cast(map, IntegerType) + assert(ret.resolved === false) + } + } + + test("cast from struct") { + val struct = Literal.create( + Row("123", "abc", "", null), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true), + StructField("d", StringType, nullable = true)))) + val struct_notNull = Literal.create( + Row("123", "abc", ""), + StructType(Seq( + StructField("a", StringType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", StringType, nullable = false)))) + + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(123, null, null, null)) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = 
false), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true), + StructField("d", BooleanType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(true, true, false, null)) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false), + StructField("d", BooleanType, nullable = true)))) + assert(ret.resolved === false) + } + + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(123, null, null)) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(true, true, false)) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false)))) + assert(ret.resolved === true) + checkEvaluation(ret, Row(true, true, false)) + } + + { + val ret = cast(struct, StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct, IntegerType) + assert(ret.resolved === false) + } + } + + test("complex casting") { + val complex = Literal.create( + Row( + Seq("123", "abc", ""), + Map("a" -> "123", "b" -> "abc", "c" -> ""), + Row(0)), + StructType(Seq( + StructField("a", + ArrayType(StringType, containsNull = false), nullable = true), + StructField("m", + MapType(StringType, StringType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("i", IntegerType, nullable = true))))))) + + val ret = cast(complex, StructType(Seq( + StructField("a", + ArrayType(IntegerType, containsNull = true), nullable = true), + StructField("m", + MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("l", LongType, nullable = true))))))) + + assert(ret.resolved === true) + checkEvaluation(ret, Row( + Seq(123, null, null), + Map("a" -> true, "b" -> true, "c" -> false), + Row(0L))) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala similarity index 62% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index b5ebe4b38e337..481b335d15dfd 100644 --- 
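The array, map, struct, and complex casting tests above all exercise one resolution rule: a cast that can introduce nulls (for example, a string such as "abc" that does not parse as an integer) only resolves when the target type tolerates nulls. Below is a minimal standalone sketch of that rule using the same `Literal`, `Cast`, and `ArrayType` APIs the tests use; the object name is illustrative and not part of this patch.

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

// Illustrative sketch only, mirroring the "cast from array" assertions above.
object CastNullabilitySketch extends App {
  // "abc" and "" cannot be parsed as integers, so a string-to-int cast may yield nulls.
  val strings = Literal.create(
    Seq("123", "abc", ""), ArrayType(StringType, containsNull = false))

  // Target element type admits nulls: the cast resolves.
  println(Cast(strings, ArrayType(IntegerType, containsNull = true)).resolved)  // true
  // Target element type forbids nulls: the cast cannot be resolved.
  println(Cast(strings, ArrayType(IntegerType, containsNull = false)).resolved) // false
}
```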
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,37 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ /** - * Overrides our expression evaluation tests to use code generation for evaluation. + * Additional tests for code generation. */ -class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val evaluated = GenerateProjection.expressionEvaluator(expression) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} - |$e - """.stripMargin) - } - - val actual = plan(inputRow).apply(0) - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - +class CodeGenerationSuite extends SparkFunSuite { test("multithreaded eval") { import scala.concurrent._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala new file mode 100644 index 0000000000000..f151dd2a47f78 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ + + +class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("CreateStruct") { + val row = Row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + } + + test("complex type") { + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 + ) + + val typeS = StructType( + StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil + ) + val typeMap = MapType(StringType, StringType) + val typeArray = ArrayType(StringType) + + checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), + Literal("aa")), "bb", row) + checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation( + GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) + checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), + Literal.create(null, StringType)), null, row) + + checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), + Literal(1)), "bb", row) + checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation( + GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) + checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), + Literal.create(null, IntegerType)), null, row) + + def getStructField(expr: Expression, fieldName: String): ExtractValue = { + expr.dataType match { + case StructType(fields) => + val field = fields.find(_.name == fieldName).get + GetStructField(expr, field, fields.indexOf(field)) + } + } + + def quickResolve(u: UnresolvedExtractValue): ExtractValue = { + ExtractValue(u.child, u.extraction, _ == _) + } + + checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) + checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) + + val typeS_notNullable = StructType( + StructField("a", StringType, nullable = false) + :: StructField("b", StringType, nullable = false) :: Nil + ) + + assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) + assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable + === false) + + assert(getStructField(Literal.create(null, typeS), "a").nullable === true) + assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) + + checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) + checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) + checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) + } + + test("error message of ExtractValue") { + val structType = StructType(StructField("a", StringType, true) :: Nil) + val arrayStructType = ArrayType(structType) + val arrayType = ArrayType(StringType) + val otherType = StringType + + def checkErrorMessage( + childDataType: DataType, + fieldDataType: DataType, + errorMesage: String): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + ExtractValue( + Literal.create(null, childDataType), + Literal.create(null, fieldDataType), + _ == _) + 
} + assert(e.getMessage().contains(errorMesage)) + } + + checkErrorMessage(structType, IntegerType, "Field name should be String Literal") + checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal") + checkErrorMessage(arrayType, StringType, "Array index should be integral type") + checkErrorMessage(otherType, StringType, "Can't extract value from") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala new file mode 100644 index 0000000000000..152c4e4111244 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{IntegerType, BooleanType} + + +class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("case when") { + val row = create_row(null, false, true, "a", "b", "c") + val c1 = 'a.boolean.at(0) + val c2 = 'a.boolean.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) + + checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + + assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + + val c4_notNull = 'a.boolean.notNull.at(3) + val c5_notNull = 'a.boolean.notNull.at(4) + val c6_notNull = 'a.boolean.notNull.at(5) + + assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, 
c3, c5, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + } + + test("case key when") { + val row = create_row(null, 1, 2, "a", "b", "c") + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + val literalNull = Literal.create(null, IntegerType) + val literalInt = Literal(1) + val literalString = Literal("a") + + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) + checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) + checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row) + + checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) + checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala new file mode 100644 index 0000000000000..87a92b87962f8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} + +/** + * A few helper functions for expression evaluation testing. Mixin this trait to use them. 
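A side note on the encoding used throughout the new ConditionalExpressionSuite: `CaseWhen` takes a flat sequence of alternating condition/value expressions, optionally followed by a single ELSE value, and a null (unknown) condition is treated as not matching. A small hedged sketch of that behaviour (object name illustrative, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, EmptyRow, Literal}
import org.apache.spark.sql.types.BooleanType

// Illustrative sketch of the Seq(cond1, value1, cond2, value2, elseValue) encoding.
object CaseWhenSketch extends App {
  val expr = CaseWhen(Seq(
    Literal.create(null, BooleanType), Literal(1),  // unknown condition -> skipped
    Literal(true),                     Literal(2),  // first true condition wins
    Literal(3)))                                    // trailing ELSE value
  println(expr.eval(EmptyRow))  // 2
}
```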
+ */ +trait ExpressionEvalHelper { + self: SparkFunSuite => + + protected def create_row(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } + + protected def checkEvaluation( + expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) + checkEvaluationWithGeneratedProjection(expression, expected, inputRow) + } + + protected def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.eval(inputRow) + } + + protected def checkEvaluationWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: $expected$input") + } + } + + protected def checkEvaluationWithGeneratedMutableProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + protected def checkEvaluationWithGeneratedProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |Expressions: $expression + |Code: $evaluated + """.stripMargin) + } + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + protected def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala deleted file mode 100644 index 5c4a1527c27c9..0000000000000 --- 
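For orientation, here is a minimal sketch of how the new `ExpressionEvalHelper` trait is consumed: a suite mixes it in alongside `SparkFunSuite` and calls `checkEvaluation`, which now exercises each expression interpreted and through both `GenerateMutableProjection` and `GenerateProjection`, replacing the old `GeneratedEvaluationSuite` override. The suite name and expressions below are illustrative only, not part of this patch.

```scala
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

// Illustrative only: mirrors the pattern of the ComplexTypeSuite and
// ConditionalExpressionSuite added in this patch.
class ArithmeticSketchSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("literal addition") {
    // Checked three ways by the helper: interpreted evaluation, generated
    // mutable projection, and generated projection.
    checkEvaluation(Add(Literal(1), Literal(1)), 2)
    // Null propagation through Add.
    checkEvaluation(Add(Literal(1), Literal.create(null, IntegerType)), null)
  }
}
```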
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ /dev/null @@ -1,1371 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import java.sql.{Date, Timestamp} - -import scala.collection.immutable.HashSet - -import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.FunSuite -import org.scalatest.Matchers._ - -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.mathfuncs._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types._ - - -class ExpressionEvaluationBaseSuite extends FunSuite { - - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation( - expression: Expression, - expected: Spread[Double], - inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } -} - -class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { - - def create_row(values: Any*): Row = { - new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) - } - - test("literals") { - checkEvaluation(Literal(1), 1) - checkEvaluation(Literal(true), true) - checkEvaluation(Literal(0L), 0L) - checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal(1) + Literal(1), 2) - } - - test("unary BitwiseNOT") { - checkEvaluation(BitwiseNot(1), -2) - assert(BitwiseNot(1).dataType === IntegerType) - assert(BitwiseNot(1).eval(EmptyRow).isInstanceOf[Int]) - checkEvaluation(BitwiseNot(1.toLong), -2.toLong) - assert(BitwiseNot(1.toLong).dataType === LongType) - assert(BitwiseNot(1.toLong).eval(EmptyRow).isInstanceOf[Long]) - checkEvaluation(BitwiseNot(1.toShort), -2.toShort) - assert(BitwiseNot(1.toShort).dataType === ShortType) - assert(BitwiseNot(1.toShort).eval(EmptyRow).isInstanceOf[Short]) - checkEvaluation(BitwiseNot(1.toByte), -2.toByte) - assert(BitwiseNot(1.toByte).dataType === ByteType) - assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) - } - 
- // scalastyle:off - /** - * Checks for three-valued-logic. Based on: - * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * I.e. in flat cpo "False -> Unknown -> True", - * OR is lowest upper bound, - * AND is greatest lower bound. - * p q p OR q p AND q p = q - * True True True True True - * True False True False False - * True Unknown True Unknown Unknown - * False True True False False - * False False False False True - * False Unknown Unknown False Unknown - * Unknown True True Unknown Unknown - * Unknown False Unknown False Unknown - * Unknown Unknown Unknown Unknown Unknown - * - * p NOT p - * True False - * False True - * Unknown Unknown - */ - // scalastyle:on - val notTrueTable = - (true, false) :: - (false, true) :: - (null, null) :: Nil - - test("3VL Not") { - notTrueTable.foreach { - case (v, answer) => - checkEvaluation(!Literal.create(v, BooleanType), answer) - } - } - - booleanLogicTest("AND", _ && _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: - (false, false, false) :: - (false, null, false) :: - (null, true, null) :: - (null, false, false) :: - (null, null, null) :: Nil) - - booleanLogicTest("OR", _ || _, - (true, true, true) :: - (true, false, true) :: - (true, null, true) :: - (false, true, true) :: - (false, false, false) :: - (false, null, null) :: - (null, true, true) :: - (null, false, null) :: - (null, null, null) :: Nil) - - booleanLogicTest("=", _ === _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: - (false, false, true) :: - (false, null, null) :: - (null, true, null) :: - (null, false, null) :: - (null, null, null) :: Nil) - - def booleanLogicTest( - name: String, - op: (Expression, Expression) => Expression, - truthTable: Seq[(Any, Any, Any)]) { - test(s"3VL $name") { - truthTable.foreach { - case (l,r,answer) => - val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) - checkEvaluation(expr, answer) - } - } - } - - test("IN") { - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation( - In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), - true) - } - - test("Divide") { - checkEvaluation(Divide(Literal(2), Literal(1)), 2) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1), Literal(0)), null) - checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("Remainder") { - checkEvaluation(Remainder(Literal(2), Literal(1)), 0) - checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) - checkEvaluation(Remainder(Literal(1), Literal(2)), 1) - checkEvaluation(Remainder(Literal(1), Literal(0)), null) - 
checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("INSET") { - val hS = HashSet[Any]() + 1 + 2 - val nS = HashSet[Any]() + 1 + 2 + null - val one = Literal(1) - val two = Literal(2) - val three = Literal(3) - val nl = Literal(null) - val s = Seq(one, two) - val nullS = Seq(one, two, null) - checkEvaluation(InSet(one, hS), true) - checkEvaluation(InSet(two, hS), true) - checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(nl, nS), true) - checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), false) - checkEvaluation(InSet(one, hS) && InSet(two, hS), true) - } - - test("MaxOf") { - checkEvaluation(MaxOf(1, 2), 2) - checkEvaluation(MaxOf(2, 1), 2) - checkEvaluation(MaxOf(1L, 2L), 2L) - checkEvaluation(MaxOf(2L, 1L), 2L) - - checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) - } - - test("MinOf") { - checkEvaluation(MinOf(1, 2), 1) - checkEvaluation(MinOf(2, 1), 1) - checkEvaluation(MinOf(1L, 2L), 1L) - checkEvaluation(MinOf(2L, 1L), 1L) - - checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) - checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) - } - - test("LIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType).like("a"), null) - checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) - checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) - checkEvaluation("abdef" like "abdef", true) - checkEvaluation("a_%b" like "a\\__b", true) - checkEvaluation("addb" like "a_%b", true) - checkEvaluation("addb" like "a\\__b", false) - checkEvaluation("addb" like "a%\\%b", false) - checkEvaluation("a_%b" like "a%\\%b", true) - checkEvaluation("addb" like "a%", true) - checkEvaluation("addb" like "**", false) - checkEvaluation("abc" like "a%", true) - checkEvaluation("abc" like "b%", false) - checkEvaluation("abc" like "bc%", false) - checkEvaluation("a\nb" like "a_b", true) - checkEvaluation("ab" like "a%b", true) - checkEvaluation("a\nb" like "a%b", true) - } - - test("LIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(null)) - checkEvaluation("abdef" like regEx, true, create_row("abdef")) - checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) - checkEvaluation("addb" like regEx, true, create_row("a_%b")) - checkEvaluation("addb" like regEx, false, create_row("a\\__b")) - checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) - checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) - checkEvaluation("addb" like regEx, true, create_row("a%")) - checkEvaluation("addb" like regEx, false, create_row("**")) - checkEvaluation("abc" like regEx, true, create_row("a%")) - checkEvaluation("abc" like regEx, false, create_row("b%")) - checkEvaluation("abc" like regEx, false, 
create_row("bc%")) - checkEvaluation("a\nb" like regEx, true, create_row("a_b")) - checkEvaluation("ab" like regEx, true, create_row("a%b")) - checkEvaluation("a\nb" like regEx, true, create_row("a%b")) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) - } - - test("RLIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal.create(null, StringType), null) - checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) - checkEvaluation("abdef" rlike "abdef", true) - checkEvaluation("abbbbc" rlike "a.*c", true) - - checkEvaluation("fofo" rlike "^fo", true) - checkEvaluation("fo\no" rlike "^fo\no$", true) - checkEvaluation("Bn" rlike "^Ba*n", true) - checkEvaluation("afofo" rlike "fo", true) - checkEvaluation("afofo" rlike "^fo", false) - checkEvaluation("Baan" rlike "^Ba?n", false) - checkEvaluation("axe" rlike "pi|apa", false) - checkEvaluation("pip" rlike "^(pi)*$", false) - - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") - } - } - - test("RLIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) - checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) - checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) - checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) - checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row("**")) - } - } - - test("data type casting") { - - val sd = "1970-01-01" - val d = Date.valueOf(sd) - val zts = sd + " 00:00:00" - val sts = sd + " 00:00:02" - val nts = sts + ".1" - val ts = Timestamp.valueOf(nts) - - checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType.Unlimited, null) - checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) - - checkEvaluation(Literal(1) cast LongType, 1) - checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) - checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) - checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - - checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) - checkEvaluation(Cast(Literal(d) cast StringType, DateType), 0) - checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) - checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) - // all convert to string type to check - checkEvaluation( - Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) - checkEvaluation( - Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), zts) - - checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") - - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) - 
checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) - checkEvaluation(Literal(true) cast IntegerType, 1) - checkEvaluation(Literal(false) cast IntegerType, 0) - checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) - checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23d) - checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) - checkEvaluation("23" cast ByteType, 23.toByte) - checkEvaluation("23" cast ShortType, 23.toShort) - checkEvaluation("2012-12-11" cast DoubleType, null) - checkEvaluation(Literal(123) cast IntegerType, 123) - - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) - checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) - - intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} - - assert(("abcdef" cast StringType).nullable === false) - assert(("abcdef" cast BinaryType).nullable === false) - assert(("abcdef" cast BooleanType).nullable === false) - assert(("abcdef" cast TimestampType).nullable === true) - assert(("abcdef" cast LongType).nullable === true) - assert(("abcdef" cast IntegerType).nullable === true) - assert(("abcdef" cast ShortType).nullable === true) - assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType.Unlimited).nullable === true) - assert(("abcdef" cast DecimalType(4, 2)).nullable === true) - assert(("abcdef" cast DoubleType).nullable === true) - assert(("abcdef" cast FloatType).nullable === true) - - checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) - } - - test("date") { - val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) - val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) - checkEvaluation(Literal(d1) < Literal(d2), true) - } - - test("casting to fixed-precision decimals") { - // Overflow and rounding for casting to fixed-precision decimals: - // - Values should round with HALF_UP mode by default when you lower scale - // - Values that would overflow the target precision should turn into null - // - Because of this, casts to fixed-precision decimals should be nullable - - assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === true) - assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) - - assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) - - checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) - checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) - checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) - 
checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) - - checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) - checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) - checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) - checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) - - checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) - checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) - checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) - checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) - checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) - checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) - - checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) - checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) - - checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) - checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null) - - checkEvaluation(Cast(Literal(Double.NaN), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(Float.NaN), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), DecimalType.Unlimited), null) - - checkEvaluation(Cast(Literal(Double.NaN), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(Float.NaN), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), DecimalType(2, 1)), null) - } - - test("timestamp") { - val ts1 = new Timestamp(12) - val ts2 = new Timestamp(123) - checkEvaluation(Literal("ab") < Literal("abc"), true) - checkEvaluation(Literal(ts1) < Literal(ts2), true) - } - - test("date casting") { - val d = Date.valueOf("1970-01-01") - checkEvaluation(Cast(Literal(d), ShortType), null) - checkEvaluation(Cast(Literal(d), IntegerType), null) - checkEvaluation(Cast(Literal(d), LongType), null) - checkEvaluation(Cast(Literal(d), FloatType), null) - checkEvaluation(Cast(Literal(d), DoubleType), null) - checkEvaluation(Cast(Literal(d), DecimalType.Unlimited), 
null) - checkEvaluation(Cast(Literal(d), DecimalType(10, 2)), null) - checkEvaluation(Cast(Literal(d), StringType), "1970-01-01") - checkEvaluation(Cast(Cast(Literal(d), TimestampType), StringType), "1970-01-01 00:00:00") - } - - test("timestamp casting") { - val millis = 15 * 1000 + 2 - val seconds = millis * 1000 + 2 - val ts = new Timestamp(millis) - val tss = new Timestamp(seconds) - checkEvaluation(Cast(ts, ShortType), 15) - checkEvaluation(Cast(ts, IntegerType), 15) - checkEvaluation(Cast(ts, LongType), 15) - checkEvaluation(Cast(ts, FloatType), 15.002f) - checkEvaluation(Cast(ts, DoubleType), 15.002) - checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) - checkEvaluation(Cast(Cast(tss, IntegerType), TimestampType), ts) - checkEvaluation(Cast(Cast(tss, LongType), TimestampType), ts) - checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType), - millis.toFloat / 1000) - checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), - millis.toDouble / 1000) - checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) - - // A test for higher precision than millis - checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) - - checkEvaluation(Cast(Literal(Double.NaN), TimestampType), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), TimestampType), null) - checkEvaluation(Cast(Literal(Float.NaN), TimestampType), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), TimestampType), null) - } - - test("array casting") { - val array = Literal.create(Seq("123", "abc", "", null), - ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), - ArrayType(StringType, containsNull = false)) - - { - val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(123, null, null, null)) - } - { - val cast = Cast(array, ArrayType(IntegerType, containsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(array, ArrayType(BooleanType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false, null)) - } - { - val cast = Cast(array, ArrayType(BooleanType, containsNull = false)) - assert(cast.resolved === false) - } - - { - val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(123, null, null)) - } - { - val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false)) - } - { - val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false)) - } - - { - val cast = Cast(array, IntegerType) - assert(cast.resolved === false) - } - } - - test("map casting") { - val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), - MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), - MapType(StringType, StringType, valueContainsNull = false)) - - { - val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) - } - { - 
val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) - } - { - val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(cast.resolved === false) - } - - { - val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null)) - } - { - val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) - } - { - val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) - } - { - val cast = Cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(cast.resolved === false) - } - - { - val cast = Cast(map, IntegerType) - assert(cast.resolved === false) - } - } - - test("struct casting") { - val struct = Literal.create( - Row("123", "abc", "", null), - StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", StringType, nullable = true), - StructField("c", StringType, nullable = true), - StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal.create( - Row("123", "abc", ""), - StructType(Seq( - StructField("a", StringType, nullable = false), - StructField("b", StringType, nullable = false), - StructField("c", StringType, nullable = false)))) - - { - val cast = Cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true), - StructField("d", IntegerType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(123, null, null, null)) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false), - StructField("d", IntegerType, nullable = true)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = true), - StructField("d", BooleanType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false, null)) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = false), - StructField("d", BooleanType, nullable = true)))) - assert(cast.resolved === false) - } - - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true)))) - 
assert(cast.resolved === true) - checkEvaluation(cast, Row(123, null, null)) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false)) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = false)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false)) - } - - { - val cast = Cast(struct, StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", StringType, nullable = true), - StructField("c", StringType, nullable = true)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct, IntegerType) - assert(cast.resolved === false) - } - } - - test("complex casting") { - val complex = Literal.create( - Row( - Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), - Row(0)), - StructType(Seq( - StructField("a", - ArrayType(StringType, containsNull = false), nullable = true), - StructField("m", - MapType(StringType, StringType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("i", IntegerType, nullable = true))))))) - - val cast = Cast(complex, StructType(Seq( - StructField("a", - ArrayType(IntegerType, containsNull = true), nullable = true), - StructField("m", - MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("l", LongType, nullable = true))))))) - - assert(cast.resolved === true) - checkEvaluation(cast, Row( - Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), - Row(0L))) - } - - test("null checking") { - val row = create_row("^Ba*n", null, true, null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.boolean.at(3) - - checkEvaluation(c1.isNull, false, row) - checkEvaluation(c1.isNotNull, true, row) - - checkEvaluation(c2.isNull, true, row) - checkEvaluation(c2.isNotNull, false, row) - - checkEvaluation(Literal.create(1, ShortType).isNull, false) - checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - - checkEvaluation(Literal.create(null, ShortType).isNull, true) - checkEvaluation(Literal.create(null, ShortType).isNotNull, false) - - checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - - checkEvaluation( - If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) - checkEvaluation(If(c3, c1, c2), "^Ba*n", row) - checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), - Literal.create("a", StringType), Literal.create("b", 
StringType)), "b", row) - - checkEvaluation(c1 in (c1, c2), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) - } - - test("case when") { - val row = create_row(null, false, true, "a", "b", "c") - val c1 = 'a.boolean.at(0) - val c2 = 'a.boolean.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) - - checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) - - checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) - - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) - - val c4_notNull = 'a.boolean.notNull.at(3) - val c5_notNull = 'a.boolean.notNull.at(4) - val c6_notNull = 'a.boolean.notNull.at(5) - - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) - - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) - - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) - } - - test("case key when") { - val row = create_row(null, 1, 2, "a", "b", "c") - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) - - val literalNull = Literal.create(null, BooleanType) - val literalInt = Literal(1) - val literalString = Literal("a") - - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) - checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) - checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) - checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row) - - checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) - checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) - checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row) - } - - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val typeS = StructType( - 
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - - def getStructField(expr: Expression, fieldName: String): ExtractValue = { - expr.dataType match { - case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) - } - } - - def quickResolve(u: UnresolvedExtractValue): ExtractValue = { - ExtractValue(u.child, u.extraction, _ == _) - } - - checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) - - val typeS_notNullable = StructType( - StructField("a", StringType, nullable = false) - :: StructField("b", StringType, nullable = false) :: Nil - ) - - assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable - === false) - - assert(getStructField(Literal.create(null, typeS), "a").nullable === true) - assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) - - checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) - checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) - checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) - } - - test("error message of ExtractValue") { - val structType = StructType(StructField("a", StringType, true) :: Nil) - val arrayStructType = ArrayType(structType) - val arrayType = ArrayType(StringType) - val otherType = StringType - - def checkErrorMessage( - childDataType: DataType, - fieldDataType: DataType, - errorMesage: String): Unit = { - val e = intercept[org.apache.spark.sql.AnalysisException] { - ExtractValue( - Literal.create(null, childDataType), - Literal.create(null, fieldDataType), - _ == _) - } - assert(e.getMessage().contains(errorMesage)) - } - - checkErrorMessage(structType, IntegerType, "Field name should be String Literal") - checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal") - checkErrorMessage(arrayType, StringType, "Array index should be integral type") - checkErrorMessage(otherType, StringType, "Can't extract value from") - } - - test("arithmetic") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) - - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3, 
row) - checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(-c1, -1, row) - checkEvaluation(c1 + c2, 3, row) - checkEvaluation(c1 - c2, -1, row) - checkEvaluation(c1 * c2, 2, row) - checkEvaluation(c1 / c2, 0, row) - checkEvaluation(c1 % c2, 1, row) - } - - test("fractional arithmetic") { - val row = create_row(1.1, 2.0, 3.1, null) - val c1 = 'a.double.at(0) - val c2 = 'a.double.at(1) - val c3 = 'a.double.at(2) - val c4 = 'a.double.at(3) - - checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) - checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) - - checkEvaluation(-c1, -1.1, row) - checkEvaluation(c1 + c2, 3.1, row) - checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) - checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) - checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) - checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) - } - - test("BinaryComparison") { - val row = create_row(1, 2, 3, null, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - val c5 = 'a.int.at(4) - val c6 = 'a.int.at(5) - - checkEvaluation(LessThan(c1, c4), null, row) - checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 < c2, true, row) - checkEvaluation(c1 <= c2, true, row) - checkEvaluation(c1 > c2, false, row) - checkEvaluation(c1 >= c2, false, row) - checkEvaluation(c1 === c2, false, row) - checkEvaluation(c1 !== c2, true, row) - checkEvaluation(c4 <=> c1, false, row) - checkEvaluation(c1 <=> c4, false, row) - checkEvaluation(c4 <=> c6, true, row) - checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) - checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) - } - - test("StringComparison") { - val row = create_row("abc", null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - - checkEvaluation(c1 contains "b", true, row) - checkEvaluation(c1 contains "x", false, row) - checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal.create(null, StringType), null, row) - - checkEvaluation(c1 startsWith "a", true, row) - checkEvaluation(c1 startsWith "b", false, row) - checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) - - checkEvaluation(c1 endsWith "c", true, row) - checkEvaluation(c1 endsWith "b", false, row) - checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) - } - - test("Substring") { - val row = create_row("example", "example".toArray.map(_.toByte)) - - val s = 'a.string.at(0) - - // substring from zero position with less-than-full length - checkEvaluation( - 
Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) - - // substring from zero position with full length - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) - - // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), - "example", row) - checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), - "example", row) - - // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), - "xa", row) - - // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), - "xample", row) - - // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), - "xample", row) - - // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), - "", row) - - // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - "", row) - - // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - null, create_row(null)) - - // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), - null, row) - - // substring(_, _, null) -> null - checkEvaluation( - Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), - null, - row) - - // 2-arg substring from zero position - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "example", - row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "example", - row) - - // 2-arg substring from nonzero position - checkEvaluation( - Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "xample", - row) - - val s_notNull = 'a.string.notNull.at(0) - - assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable - === true) - assert( - Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable - === false) - assert(Substring(s_notNull, - Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, - Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) - - checkEvaluation(s.substr(0, 2), "ex", row) - checkEvaluation(s.substr(0), "example", row) - checkEvaluation(s.substring(0, 2), "ex", row) - checkEvaluation(s.substring(0), "example", row) - } - - test("SQRT") { - val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) - val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => create_row(l.toDouble)) - val d = 'a.double.at(0) - - for ((row, 
expected) <- rowSequence zip expectedResults) { - checkEvaluation(Sqrt(d), expected, row) - } - - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(-1), null, EmptyRow) - checkEvaluation(Sqrt(-1.5), null, EmptyRow) - } - - test("Bitwise operations") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(BitwiseAnd(c1, c4), null, row) - checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseOr(c1, c4), null, row) - checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseXor(c1, c4), null, row) - checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseNot(c4), null, row) - checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 & c2, 0, row) - checkEvaluation(c1 | c2, 3, row) - checkEvaluation(c1 ^ c2, 3, row) - checkEvaluation(~c1, -2, row) - } - - /** - * Used for testing math functions for DataFrames. - * @param c The DataFrame function - * @param f The functions in scala.math - * @param domain The set of values to run the function with - * @param expectNull Whether the given values should return null or not - * @tparam T Generic type for primitives - */ - def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( - c: Expression => Expression, - f: T => T, - domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false): Unit = { - if (expectNull) { - domain.foreach { value => - checkEvaluation(c(Literal(value)), null, EmptyRow) - } - } else { - domain.foreach { value => - checkEvaluation(c(Literal(value)), f(value), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) - } - - test("sin") { - unaryMathFunctionEvaluation(Sin, math.sin) - } - - test("asin") { - unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) - } - - test("sinh") { - unaryMathFunctionEvaluation(Sinh, math.sinh) - } - - test("cos") { - unaryMathFunctionEvaluation(Cos, math.cos) - } - - test("acos") { - unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) - } - - test("cosh") { - unaryMathFunctionEvaluation(Cosh, math.cosh) - } - - test("tan") { - unaryMathFunctionEvaluation(Tan, math.tan) - } - - test("atan") { - unaryMathFunctionEvaluation(Atan, math.atan) - } - - test("tanh") { - unaryMathFunctionEvaluation(Tanh, math.tanh) - } - - test("toDegrees") { - unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) - } - - test("toRadians") { - unaryMathFunctionEvaluation(ToRadians, math.toRadians) - } - - test("cbrt") { - unaryMathFunctionEvaluation(Cbrt, math.cbrt) - } - - test("ceil") { - 
unaryMathFunctionEvaluation(Ceil, math.ceil) - } - - test("floor") { - unaryMathFunctionEvaluation(Floor, math.floor) - } - - test("rint") { - unaryMathFunctionEvaluation(Rint, math.rint) - } - - test("exp") { - unaryMathFunctionEvaluation(Exp, math.exp) - } - - test("expm1") { - unaryMathFunctionEvaluation(Expm1, math.expm1) - } - - test("signum") { - unaryMathFunctionEvaluation[Double](Signum, math.signum) - } - - test("log") { - unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) - } - - test("log10") { - unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) - } - - test("log1p") { - unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) - } - - /** - * Used for testing math functions for DataFrames. - * @param c The DataFrame function - * @param f The functions in scala.math - * @param domain The set of values to run the function with - */ - def binaryMathFunctionEvaluation( - c: (Expression, Expression) => Expression, - f: (Double, Double) => Double, - domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false): Unit = { - if (expectNull) { - domain.foreach { case (v1, v2) => - checkEvaluation(c(v1, v2), null, create_row(null)) - } - } else { - domain.foreach { case (v1, v2) => - checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) - checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) - checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) - } - - test("pow") { - binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) - } - - test("hypot") { - binaryMathFunctionEvaluation(Hypot, math.hypot) - } - - test("atan2") { - binaryMathFunctionEvaluation(Atan2, math.atan2) - } -} - -// TODO: Make the tests work with codegen. -class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { - - test("CreateStruct") { - val row = Row(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala new file mode 100644 index 0000000000000..dcb3635c5ccae --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + +class ExpressionTypeCheckingSuite extends SparkFunSuite { + + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) + + def assertError(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains( + s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains(errorMessage)) + } + + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze + SimpleAnalyzer.checkAnalysis(analyzed) + } + + def assertErrorForDifferingTypes(expr: Expression): Unit = { + assertError(expr, + s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + } + + test("check types for unary arithmetic") { + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt + assertError(Sqrt('booleanField), "function sqrt accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + } + + test("check types for binary arithmetic") { + // We will cast String to Double for binary arithmetic + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + assertErrorForDifferingTypes(Add('intField, 'booleanField)) + assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) + assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) + assertErrorForDifferingTypes(Divide('intField, 'booleanField)) + assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) + assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) + assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 
'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + } + + test("check types for predicates") { + // We will cast String to Double for binary comparison + assertSuccess(EqualTo('intField, 'stringField)) + assertSuccess(EqualNullSafe('intField, 'stringField)) + assertSuccess(LessThan('intField, 'stringField)) + assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + // We will transform EqualTo with numeric and boolean types to CaseKeyWhen + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) + + assertError(EqualTo('intField, 'complexField), "differing types") + assertError(EqualNullSafe('intField, 'complexField), "differing types") + + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) + assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + + assertError( + LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError( + LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError( + GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError( + GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError( + If('intField, 'stringField, 'stringField), + "type of predicate expression in If should be boolean") + assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) + + assertError( + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + "WHEN expressions in CaseWhen should all be boolean type") + + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala deleted file mode 100644 index 97af2e0fd0502..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen._ - -/** - * Overrides our expression evaluation tests to use generated code on mutable rows. - */ -class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - lazy val evaluated = GenerateProjection.expressionEvaluator(expression) - - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} - |$e - """.stripMargin) - } - - val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code.mkString("\n")} - """.stripMargin) - } - if (actual != expectedRow) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala new file mode 100644 index 0000000000000..f44f55dfb92d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.StringType + + +class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + // TODO: Add tests for all data types. 
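A minimal sketch of the kind of additional coverage the TODO above points at, using the same round-trip pattern as the tests that follow; the chosen types and values are illustrative assumptions, not part of this patch:

  // Hypothetical extra cases (sketch only, not in the patch): the Literal factory
  // also accepts the remaining integral widths, so the same round-trip check applies.
  test("byte, short and long literals (sketch)") {
    checkEvaluation(Literal(1.toByte), 1.toByte)
    checkEvaluation(Literal(1.toShort), 1.toShort)
    checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
  }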
+ + test("boolean literals") { + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(false), false) + } + + test("int literals") { + checkEvaluation(Literal(1), 1) + checkEvaluation(Literal(0L), 0L) + } + + test("double literals") { + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { + d => { + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) + } + } + } + + test("string literals") { + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal.create(null, StringType), null) + } + + test("sum two literals") { + checkEvaluation(Add(Literal(1), Literal(1)), 2) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala new file mode 100644 index 0000000000000..25ebc70d095d8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.DoubleType + +class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + /** + * Used for testing unary math expressions. + * + * @param c expression + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + private def testUnary[T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + /** + * Used for testing binary math expressions. 
+ * + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + private def testBinary( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null)) + checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + testUnary(Sin, math.sin) + } + + test("asin") { + testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true) + } + + test("sinh") { + testUnary(Sinh, math.sinh) + } + + test("cos") { + testUnary(Cos, math.cos) + } + + test("acos") { + testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true) + } + + test("cosh") { + testUnary(Cosh, math.cosh) + } + + test("tan") { + testUnary(Tan, math.tan) + } + + test("atan") { + testUnary(Atan, math.atan) + } + + test("tanh") { + testUnary(Tanh, math.tanh) + } + + test("toDegrees") { + testUnary(ToDegrees, math.toDegrees) + } + + test("toRadians") { + testUnary(ToRadians, math.toRadians) + } + + test("cbrt") { + testUnary(Cbrt, math.cbrt) + } + + test("ceil") { + testUnary(Ceil, math.ceil) + } + + test("floor") { + testUnary(Floor, math.floor) + } + + test("rint") { + testUnary(Rint, math.rint) + } + + test("exp") { + testUnary(Exp, math.exp) + } + + test("expm1") { + testUnary(Expm1, math.expm1) + } + + test("signum") { + testUnary[Double](Signum, math.signum) + } + + test("log") { + testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true) + } + + test("log10") { + testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true) + } + + test("log1p") { + testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) + } + + test("pow") { + testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) + } + + test("hypot") { + testBinary(Hypot, math.hypot) + } + + test("atan2") { + testBinary(Atan2, math.atan2) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala new file mode 100644 index 0000000000000..ccdada8b56f83 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{BooleanType, StringType, ShortType} + +class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("null checking") { + val row = create_row("^Ba*n", null, true, null) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.boolean.at(3) + + checkEvaluation(c1.isNull, false, row) + checkEvaluation(c1.isNotNull, true, row) + + checkEvaluation(c2.isNull, true, row) + checkEvaluation(c2.isNotNull, false, row) + + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) + + checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) + + checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + + checkEvaluation( + If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) + checkEvaluation(If(c3, c1, c2), "^Ba*n", row) + checkEvaluation(If(c4, c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) + + checkEvaluation(c1 in (c1, c2), true, row) + checkEvaluation( + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) + checkEvaluation( + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala new file mode 100644 index 0000000000000..b6261bfba0786 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} + +import scala.collection.immutable.HashSet + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.types.{IntegerType, BooleanType} + + +class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def booleanLogicTest( + name: String, + op: (Expression, Expression) => Expression, + truthTable: Seq[(Any, Any, Any)]) { + test(s"3VL $name") { + truthTable.foreach { + case (l, r, answer) => + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) + checkEvaluation(expr, answer) + } + } + } + + // scalastyle:off + /** + * Checks for three-valued-logic. Based on: + * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 + * I.e. in flat cpo "False -> Unknown -> True", + * OR is lowest upper bound, + * AND is greatest lower bound. + * p q p OR q p AND q p = q + * True True True True True + * True False True False False + * True Unknown True Unknown Unknown + * False True True False False + * False False False False True + * False Unknown Unknown False Unknown + * Unknown True True Unknown Unknown + * Unknown False Unknown False Unknown + * Unknown Unknown Unknown Unknown Unknown + * + * p NOT p + * True False + * False True + * Unknown Unknown + */ + // scalastyle:on + val notTrueTable = + (true, false) :: + (false, true) :: + (null, null) :: Nil + + test("3VL Not") { + notTrueTable.foreach { case (v, answer) => + checkEvaluation(Not(Literal.create(v, BooleanType)), answer) + } + } + + booleanLogicTest("AND", And, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, false) :: + (false, null, false) :: + (null, true, null) :: + (null, false, false) :: + (null, null, null) :: Nil) + + booleanLogicTest("OR", Or, + (true, true, true) :: + (true, false, true) :: + (true, null, true) :: + (false, true, true) :: + (false, false, false) :: + (false, null, null) :: + (null, true, true) :: + (null, false, null) :: + (null, null, null) :: Nil) + + booleanLogicTest("=", EqualTo, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, true) :: + (false, null, null) :: + (null, true, null) :: + (null, false, null) :: + (null, null, null) :: Nil) + + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation( + And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), + true) + } + + test("INSET") { + val hS = HashSet[Any]() + 1 + 2 + val nS = HashSet[Any]() + 1 + 2 + null + val one = Literal(1) + val two = Literal(2) + val three = Literal(3) + val nl = Literal(null) + val s = Seq(one, two) + val nullS = Seq(one, two, null) + checkEvaluation(InSet(one, hS), 
true) + checkEvaluation(InSet(two, hS), true) + checkEvaluation(InSet(two, nS), true) + checkEvaluation(InSet(nl, nS), true) + checkEvaluation(InSet(three, hS), false) + checkEvaluation(InSet(three, nS), false) + checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + } + + + test("BinaryComparison") { + val row = create_row(1, 2, 3, null, 3, null) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + val c5 = 'a.int.at(4) + val c6 = 'a.int.at(5) + + checkEvaluation(LessThan(c1, c4), null, row) + checkEvaluation(LessThan(c1, c2), true, row) + checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(c1 < c2, true, row) + checkEvaluation(c1 <= c2, true, row) + checkEvaluation(c1 > c2, false, row) + checkEvaluation(c1 >= c2, false, row) + checkEvaluation(c1 === c2, false, row) + checkEvaluation(c1 !== c2, true, row) + checkEvaluation(c4 <=> c1, false, row) + checkEvaluation(c1 <=> c4, false, row) + checkEvaluation(c4 <=> c6, true, row) + checkEvaluation(c3 <=> c5, true, row) + checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) + checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) + + val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) + val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) + checkEvaluation(Literal(d1) < Literal(d2), true) + + val ts1 = new Timestamp(12) + val ts2 = new Timestamp(123) + checkEvaluation(Literal("ab") < Literal("abc"), true) + checkEvaluation(Literal(ts1) < Literal(ts2), true) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala new file mode 100644 index 0000000000000..2e81296c4e623 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{IntegerType, StringType} + + +class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("StringComparison") { + val row = create_row("abc", null) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + + checkEvaluation(c1 contains "b", true, row) + checkEvaluation(c1 contains "x", false, row) + checkEvaluation(c2 contains "b", null, row) + checkEvaluation(c1 contains Literal.create(null, StringType), null, row) + + checkEvaluation(c1 startsWith "a", true, row) + checkEvaluation(c1 startsWith "b", false, row) + checkEvaluation(c2 startsWith "a", null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) + + checkEvaluation(c1 endsWith "c", true, row) + checkEvaluation(c1 endsWith "b", false, row) + checkEvaluation(c2 endsWith "b", null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) + } + + test("Substring") { + val row = create_row("example", "example".toArray.map(_.toByte)) + + val s = 'a.string.at(0) + + // substring from zero position with less-than-full length + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) + + // substring from zero position with full length + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) + + // substring from zero position with greater-than-full length + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), + "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), + "example", row) + + // substring from nonzero position with less-than-full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), + "xa", row) + + // substring from nonzero position with full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), + "xample", row) + + // substring from nonzero position with greater-than-full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), + "xample", row) + + // zero-length substring (within string bounds) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), + "", row) + + // zero-length substring (beyond string bounds) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + "", row) + + // substring(null, _, _) -> null + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + null, create_row(null)) + + // substring(_, null, _) -> null + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), + null, row) + + // substring(_, _, null) -> null + checkEvaluation( + Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), + null, + row) + + // 2-arg substring from zero position + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + 
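    // Illustrative aside, not part of the patch: the checks above follow SQL/Hive-style
    // one-based indexing, where a start position of 0 behaves like 1, and a negative
    // start is assumed to count back from the end of the string, e.g.:
    checkEvaluation(Substring(s, Literal.create(-2, IntegerType), Literal.create(2, IntegerType)),
      "le", row)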
checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + + // 2-arg substring from nonzero position + checkEvaluation( + Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "xample", + row) + + val s_notNull = 'a.string.notNull.at(0) + + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === true) + assert( + Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === false) + assert(Substring(s_notNull, + Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, + Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) + + checkEvaluation(s.substr(0, 2), "ex", row) + checkEvaluation(s.substr(0), "example", row) + checkEvaluation(s.substring(0, 2), "ex", row) + checkEvaluation(s.substring(0), "example", row) + } + + test("LIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) + checkEvaluation("a_%b" like "a\\__b", true) + checkEvaluation("addb" like "a_%b", true) + checkEvaluation("addb" like "a\\__b", false) + checkEvaluation("addb" like "a%\\%b", false) + checkEvaluation("a_%b" like "a%\\%b", true) + checkEvaluation("addb" like "a%", true) + checkEvaluation("addb" like "**", false) + checkEvaluation("abc" like "a%", true) + checkEvaluation("abc" like "b%", false) + checkEvaluation("abc" like "bc%", false) + checkEvaluation("a\nb" like "a_b", true) + checkEvaluation("ab" like "a%b", true) + checkEvaluation("a\nb" like "a%b", true) + } + + test("LIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) + } + + test("RLIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike "abdef", true) + checkEvaluation("abbbbc" rlike "a.*c", true) + + checkEvaluation("fofo" rlike "^fo", true) + checkEvaluation("fo\no" rlike "^fo\no$", true) + 
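    // Illustrative aside, not part of the patch: rlike uses java.util.regex semantics,
    // so an unanchored pattern matches when it occurs anywhere in the input, while
    // ^ and $ anchor it; two assumed examples:
    checkEvaluation("spark" rlike "par", true)
    checkEvaluation("spark" rlike "^par", false)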
checkEvaluation("Bn" rlike "^Ba*n", true) + checkEvaluation("afofo" rlike "fo", true) + checkEvaluation("afofo" rlike "^fo", false) + checkEvaluation("Baan" rlike "^Ba?n", false) + checkEvaluation("axe" rlike "pi|apa", false) + checkEvaluation("pip" rlike "^(pi)*$", false) + + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike "**") + } + } + + test("RLIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike regEx, create_row("**")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7a19e511eb8b5..88a36aa121b55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.JavaConverters._ import scala.util.Random +import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.sql.types._ -class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach { import UnsafeFixedWidthAggregationMap._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 3a60c7fd32675..61722f1ffa462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Arrays -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods -class UnsafeRowConverterSuite extends FunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala 
index 6255578d7fa57..465a5e6914204 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -78,9 +78,9 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { test("(a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ...") { checkCondition('b > 3 || 'c > 5, 'b > 3 || 'c > 5) - checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) + checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) - checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) + checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) val input = ('a === 'b && 'b > 3 && 'c > 2) || ('a === 'b && 'c < 1 && 'a === 5) || diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index a30052b38fc11..06c592f4905a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -71,7 +71,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("limits: combines two limits after ColumnPruning") { val originalQuery = testRelation @@ -79,7 +79,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(2) .select('a) .limit(5) - + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 5697c2272b8e8..ec3b2f1edfa05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -248,7 +248,7 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("Constant folding test: Fold In(v, list) into true or false") { var originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index a4a3a66b8b229..f33a18d53b1a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ * Overrides our expression evaluation tests and reruns them after optimization has occured. This * is to ensure that constant folding and other optimizations do not break anything. 
*/ -class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { +class ExpressionOptimizationSuite extends SparkFunSuite with ExpressionEvalHelper { override def checkEvaluation( expression: Expression, expected: Any, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index c0fe5db9b60f0..ffdc673cdc455 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -94,11 +94,11 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("column pruning for Project(ne, Limit)") { val originalQuery = testRelation - .select('a,'b) + .select('a, 'b) .limit(2) .select('a) @@ -110,7 +110,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + // After this line is unimplemented. test("simple push down") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 3eb399e68e70c..1d433275fed2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -46,7 +46,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -57,17 +57,17 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala new file mode 100644 index 0000000000000..151654bffbd66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ProjectCollapsingSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two deterministic, independent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two deterministic, dependent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select(('a_plus_1 + 1).as('a_plus_2), 'b) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation.select( + (('a + 1).as('a_plus_1) + 1).as('a_plus_2), + 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse nondeterministic projects") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala new file mode 100644 index 0000000000000..df29a62ff0e15 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
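The new ProjectCollapsingSuite above pins down two behaviours: adjacent deterministic Projects are merged by inlining aliases, while Projects over nondeterministic expressions are left alone, since inlining Rand would evaluate it once per reference. A rough DataFrame-level illustration of the same two cases, assuming a SQLContext named sqlContext:

```scala
import org.apache.spark.sql.functions.{col, rand}

val df = sqlContext.range(0, 10).toDF("a")

// Deterministic: the optimizer can collapse these two selects into a single
// Project computing (a + 1) + 1 AS a_plus_2.
val collapsed = df.select((col("a") + 1).as("a_plus_1"))
  .select((col("a_plus_1") + 1).as("a_plus_2"))

// Nondeterministic: inlining `r` into the outer select would evaluate rand(10)
// once per reference, so the two Projects are kept as-is.
val kept = df.select(rand(10).as("r"))
  .select((col("r") + 1).as("r1"), (col("r") + 2).as("r2"))
```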
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceDistinctWithAggregateSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil + } + + test("replace distinct with aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index a3ad200800b02..35f50be46b76f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -33,8 +33,8 @@ class UnionPushdownSuite extends PlanTest { UnionPushdown) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testUnion = Union(testRelation, testRelation2) test("union: filter to each side") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e7cafcc96de87..765c1e2dda99f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -26,7 +25,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Provides helper methods for comparing plans. */ -class PlanTest extends FunSuite { +class PlanTest extends SparkFunSuite { /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 1273921f6394c..62d5f6ac74885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} @@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Tests for the sameResult function of [[LogicalPlan]]. 
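The ReplaceDistinctWithAggregateSuite added above encodes a simple equivalence: a logical Distinct is just an aggregate that groups by all of the child's output columns and returns them unchanged, i.e. SELECT DISTINCT a, b behaves like GROUP BY a, b. A sketch of that rewrite at the logical-plan level, using the same Catalyst DSL as the suite:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation}

val input = LocalRelation('a.int, 'b.int)

// What the rule produces for Distinct(input): group by every output column
// and emit those same columns, i.e. the plan for "GROUP BY a, b".
val before = Distinct(input)
val after  = Aggregate(input.output, input.output, input)
```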
*/ -class SameResultSuite extends FunSuite { +class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 2a641c63f87bb..a7de7b052bdc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.trees -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -class RuleExecutorSuite extends FunSuite { +class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 3d10dab5ba34c..67db3d5e6d751 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -19,21 +19,19 @@ package org.apache.spark.sql.catalyst.trees import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegerType, StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children: Seq[Expression] = optKey.toSeq - def nullable: Boolean = true - def dataType: NullType = NullType + override def children: Seq[Expression] = optKey.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType override lazy val resolved = true - type EvaluatedType = Any - def eval(input: Row): Any = null.asInstanceOf[Any] + override def eval(input: Row): Any = null.asInstanceOf[Any] } -class TreeNodeSuite extends FunSuite { +class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } assert(after === Literal(2)) @@ -92,7 +90,7 @@ class TreeNodeSuite extends FunSuite { test("transform works on nodes with Option children") { val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) - val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } var actual = dummy1 transformDown toZero assert(actual === Dummy(Some(Literal(0)))) @@ -105,7 +103,7 @@ class TreeNodeSuite extends FunSuite { } test("preserves origin") { - CurrentOrigin.setPosition(1,1) + CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) CurrentOrigin.reset() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index d7d60efee50fa..4030a1b1df358 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{MetadataBuilder, Metadata} -class MetadataSuite extends FunSuite { +class MetadataSuite extends SparkFunSuite { val baseMetadata = new MetadataBuilder() .putString("purpose", "ml") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 3e7cf7cbb5e63..c6171b7b6916d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class DataTypeParserSuite extends FunSuite { +class DataTypeParserSuite extends SparkFunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index d797510f36685..261c4fcad24aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -class DataTypeSuite extends FunSuite { +class DataTypeSuite extends SparkFunSuite { test("construct an ArrayType") { val array = ArrayType(StringType) @@ -69,6 +69,76 @@ class DataTypeSuite extends FunSuite { } } + test("fieldsMap returns map of name to StructField") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val mapped = StructType.fieldsMap(struct.fields) + + val expected = Map( + "a" -> StructField("a", LongType), + "b" -> StructField("b", FloatType)) + + assert(mapped === expected) + } + + test("merge where right is empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType(List()) + val merged = left.merge(right) + + assert(merged === left) + } + + test("merge where left is empty") { + + val left = StructType(List()) + + val right = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val merged = left.merge(right) + + assert(right === merged) + + } + + test("merge where both are non-empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("c", LongType) :: Nil) + + val expected = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: + StructField("c", LongType) :: Nil) + + val merged = left.merge(right) + + assert(merged === expected) + } + + test("merge where right contains type conflict") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("b", LongType) :: Nil) + + intercept[SparkException] { + left.merge(right) + } + } + def 
checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) @@ -120,7 +190,7 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) checkDefaultSize(DateType, 4) - checkDefaultSize(TimestampType,12) + checkDefaultSize(TimestampType, 12) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) @@ -179,11 +249,11 @@ class DataTypeSuite extends FunSuite { expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), expected = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index a22aa6f244c48..81d7ab010f394 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite // scalastyle:off -class UTF8StringSuite extends FunSuite { +class UTF8StringSuite extends SparkFunSuite { test("basic") { def check(str: String, len: Int) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index de6a2cd448c47..28b373e258311 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.types.decimal +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.Decimal -import org.scalatest.{PrivateMethodTester, FunSuite} +import org.scalatest.PrivateMethodTester import scala.language.postfixOps -class DecimalSuite extends FunSuite with PrivateMethodTester { +class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { /** Check that a Decimal has the given string representation, precision and scale */ def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ffe95bb49188f..ed75475a87067 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-catalyst_${scala.binary.version} @@ -54,11 +61,11 @@ test - com.twitter + org.apache.parquet parquet-column - com.twitter + org.apache.parquet parquet-hadoop diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6895aa1010956..d3efa83380d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -349,7 +349,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def when(condition: Column, value: Any):Column = this.expr match { + def when(condition: Column, value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) case _ => @@ -378,7 +378,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def otherwise(value: Any):Column = this.expr match { + def otherwise(value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => if (branches.size % 2 == 0) { CaseWhen(branches :+ lit(value).expr) @@ -716,6 +716,18 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def endsWith(literal: String): Column = this.endsWith(lit(literal)) + /** + * Gives the column an alias. Same as `as`. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".alias("colB")) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def alias(alias: String): Column = as(alias) + /** * Gives the column an alias. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f968577bc5848..59f64dd4bc648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -57,14 +57,11 @@ private[sql] object DataFrame { * :: Experimental :: * A distributed collection of data organized into named columns. * - * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways - * to create a [[DataFrame]]: + * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates + * a [[DataFrame]] by pointing Spark SQL to a Parquet data set. 
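Column.alias, added above, is a straight synonym for as, mainly so that code ported from SQL or other DataFrame APIs reads naturally. A small usage sketch (df and colA are placeholders):

```scala
// All three produce an output column named colB.
df.select(df("colA").as("colB"))
df.select(df("colA").alias("colB"))
df.selectExpr("colA AS colB")
```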
* {{{ - * // Create a DataFrame from Parquet files - * val people = sqlContext.parquetFile("...") - * - * // Create a DataFrame from data sources - * val df = sqlContext.load("...", "json") + * val people = sqlContext.read.parquet("...") // in Scala + * DataFrame people = sqlContext.read().parquet("...") // in Java * }}} * * Once created, it can be manipulated using the various domain-specific-language (DSL) functions @@ -86,8 +83,8 @@ private[sql] object DataFrame { * A more concrete example in Scala: * {{{ * // To create DataFrame using SQLContext - * val people = sqlContext.parquetFile("...") - * val department = sqlContext.parquetFile("...") + * val people = sqlContext.read.parquet("...") + * val department = sqlContext.read.parquet("...") * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) @@ -98,8 +95,8 @@ private[sql] object DataFrame { * and in Java: * {{{ * // To create DataFrame using SQLContext - * DataFrame people = sqlContext.parquetFile("..."); - * DataFrame department = sqlContext.parquetFile("..."); + * DataFrame people = sqlContext.read().parquet("..."); + * DataFrame department = sqlContext.read().parquet("..."); * * people.filter("age".gt(30)) * .join(department, people.col("deptId").equalTo(department("id"))) @@ -255,7 +252,7 @@ class DataFrame private[sql]( val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } - select(newCols :_*) + select(newCols : _*) } /** @@ -398,22 +395,50 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumn: String): DataFrame = { + join(right, Seq(usingColumn)) + } + + /** + * Inner equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the columns "user_id" and "user_name" + * df1.join(df2, Seq("user_id", "user_name")) + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. These columns must exist on both sides. + * @group dfops + * @since 1.4.0 + */ + def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] - // Project only one of the join column. - val joinedCol = joined.right.resolve(usingColumn) + // Project only one of the join columns.
+ val joinedCols = usingColumns.map(col => joined.right.resolve(col)) + val condition = usingColumns.map { col => + catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col)) + }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => + catalyst.expressions.And(cond, eqTo) + } + Project( - joined.output.filterNot(_ == joinedCol), + joined.output.filterNot(joinedCols.contains(_)), Join( joined.left, joined.right, joinType = Inner, - Some(catalyst.expressions.EqualTo( - joined.left.resolve(usingColumn), - joined.right.resolve(usingColumn)))) + condition) ) } @@ -500,7 +525,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): DataFrame = { - sort((sortCol +: sortCols).map(apply) :_*) + sort((sortCol +: sortCols).map(apply) : _*) } /** @@ -531,7 +556,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*) + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) /** * Returns a new [[DataFrame]] sorted by the given expressions. @@ -540,7 +565,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*) + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) /** * Selects column based on the column name and return it as a [[Column]]. @@ -611,7 +636,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -825,7 +850,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs :_*) + groupBy().agg(aggExpr, aggExprs : _*) } /** @@ -863,7 +888,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function @@ -1039,7 +1064,7 @@ class DataFrame private[sql]( val name = field.name if (resolver(name, colName)) col.as(colName) else Column(name) } - select(colNames :_*) + select(colNames : _*) } else { select(Column("*"), col.as(colName)) } @@ -1085,6 +1110,22 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] with a column dropped. + * This version of drop accepts a Column rather than a name. + * This is a no-op if the DataFrame doesn't have a column + * with an equivalent expression. + * @group dfops + * @since 1.4.1 + */ + def drop(col: Column): DataFrame = { + val attrs = this.logicalPlan.output + val colsAfterDrop = attrs.filter { attr => + attr != col.expr + }.map(attr => Column(attr)) + select(colsAfterDrop : _*) + } + /** * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. * This is an alias for `distinct`. 
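The new Seq-based join overload gives DataFrames a JOIN ... USING: the key columns are resolved on both sides, combined into one equi-join condition, and projected only once in the output. A usage sketch with two hypothetical DataFrames df1 and df2:

```scala
// USING-style join: user_id and user_name each appear once in the result.
val joined = df1.join(df2, Seq("user_id", "user_name"))

// Contrast with an expression join, which keeps user_id from both inputs and
// therefore needs qualified references afterwards.
val verbose = df1.join(df2, df1("user_id") === df2("user_id"))
```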
@@ -1262,7 +1303,7 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) /** * Returns the number of rows in the [[DataFrame]]. @@ -1298,7 +1339,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - override def distinct: DataFrame = Distinct(logicalPlan) + override def distinct: DataFrame = dropDuplicates() /** * @group basic @@ -1444,7 +1485,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// - /** Left here for backward compatibility. */ + /** + * @deprecated As of 1.3.0, replaced by `toDF()`. + */ @deprecated("use toDF", "1.3.0") def toSchemaRDD: DataFrame = this @@ -1455,6 +1498,7 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. */ @deprecated("Use write.jdbc()", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { @@ -1473,6 +1517,7 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. */ @deprecated("Use write.jdbc()", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { @@ -1485,6 +1530,7 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output + * @deprecated As of 1.4.0, replaced by `write().parquet()`. */ @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { @@ -1508,6 +1554,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. */ @deprecated("Use write.saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String): Unit = { @@ -1526,6 +1573,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { @@ -1545,6 +1593,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { @@ -1564,6 +1613,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
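Two small behaviour notes from the hunks above: DataFrame.distinct now simply delegates to dropDuplicates(), and the new drop(col: Column) overload removes a column by expression rather than by name, as a no-op when nothing matches. A sketch with a placeholder df:

```scala
// distinct and dropDuplicates() now produce the same plan.
val deduped = df.distinct()

// Drop by Column instead of by name; silently ignored if no attribute
// matches the given expression.
val withoutA = df.drop(df("a"))
```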
* @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { @@ -1582,6 +1632,8 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1606,6 +1658,8 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1622,6 +1676,7 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().save(path)`. */ @deprecated("Use write.save(path)", "1.4.0") def save(path: String): Unit = { @@ -1632,6 +1687,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, * using the default data source configured by spark.sql.sources.default. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. */ @deprecated("Use write.mode(mode).save(path)", "1.4.0") def save(path: String, mode: SaveMode): Unit = { @@ -1642,6 +1698,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source, * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. */ @deprecated("Use write.format(source).save(path)", "1.4.0") def save(path: String, source: String): Unit = { @@ -1652,6 +1709,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source and * [[SaveMode]] specified by mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. */ @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { @@ -1662,6 +1720,8 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1676,6 +1736,8 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. 
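All of the deprecation notes being added to the save/saveAsTable family point at the same replacement: the DataFrameWriter returned by write. A few representative migrations (df, paths, and table names are placeholders):

```scala
import org.apache.spark.sql.SaveMode

// df.saveAsParquetFile("/data/out")                 becomes:
df.write.parquet("/data/out")

// df.save("/data/out", "json", SaveMode.Overwrite)  becomes:
df.write.format("json").mode(SaveMode.Overwrite).save("/data/out")

// df.saveAsTable("events", SaveMode.Append)         becomes:
df.write.mode(SaveMode.Append).saveAsTable("events")
```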
*/ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1689,6 +1751,8 @@ class DataFrame private[sql]( /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. */ @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { @@ -1699,6 +1763,8 @@ class DataFrame private[sql]( * Adds the rows from this RDD to the specified table. * Throws an exception if the table already exists. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append).saveAsTable(tableName)`. */ @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index b87efb58d51e5..2f19ec0403017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -28,5 +28,5 @@ private[sql] case class DataFrameHolder(df: DataFrame) { // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = df - def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*) + def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 5d106c1ac2674..edb9ed7bba56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -43,7 +43,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in * MLlib's Statistics. * * @param col1 the name of the column @@ -97,6 +97,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @param support The minimum frequency for an item to be considered `frequent`. Should be greater * than 1e-4. @@ -114,6 +117,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. 
* @@ -128,6 +134,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -143,6 +152,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index f730e4ae00e2b..45b3e1bc627d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -40,22 +40,22 @@ private[sql] object GroupedData { /** * The Grouping Type */ - trait GroupType + private[sql] trait GroupType /** * To indicate it's the GroupBy */ - object GroupByType extends GroupType + private[sql] object GroupByType extends GroupType /** * To indicate it's the CUBE */ - object CubeType extends GroupType + private[sql] object CubeType extends GroupType /** * To indicate it's the ROLLUP */ - object RollupType extends GroupType + private[sql] object RollupType extends GroupType } /** @@ -247,9 +247,9 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } - + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. 
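The note being added to every freqItems overload flags it as an exploratory tool: the result schema carries no compatibility guarantee and the underlying algorithm is approximate, so false positives are possible. A usage sketch with a placeholder df:

```scala
// Items appearing in at least roughly 25% of rows, per column.
val freq = df.stat.freqItems(Seq("a", "b"), 0.25)
freq.show()
// The result holds one array column per input column, e.g. a_freqItems, b_freqItems.
```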
@@ -259,7 +259,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Max) + aggregateNumericColumns(colNames : _*)(Max) } /** @@ -271,7 +271,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } /** @@ -283,7 +283,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Min) + aggregateNumericColumns(colNames : _*)(Min) } /** @@ -295,6 +295,6 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Sum) + aggregateNumericColumns(colNames : _*)(Sum) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 77c6af27d1007..be786f9b7f49e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -71,8 +71,12 @@ private[spark] object SQLConf { // Whether to perform partition discovery when loading external data sources. Default to true. val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled" + // Whether to perform partition column type inference. Default to true. + val PARTITION_COLUMN_TYPE_INFERENCE = "spark.sql.sources.partitionColumnTypeInference.enabled" + // The output committer class used by FSBasedRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. + // NOTE: This property should be set in Hadoop `Configuration` rather than Spark `SQLConf` val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass" // Whether to perform eager analysis when constructing a dataframe. @@ -250,6 +254,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def partitionDiscoveryEnabled() = getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean + private[spark] def partitionColumnTypeInferenceEnabled() = + getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE, "true").toBoolean + // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. private[spark] def schemaStringLengthThreshold: Int = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1ea596dddff02..8cad3885b7d46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -27,8 +27,6 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import com.google.common.reflect.TypeToken - import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} @@ -122,7 +120,11 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? 
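The new spark.sql.sources.partitionColumnTypeInference.enabled key added to SQLConf is read like any other SQL setting; when disabled, discovered partition column values are kept as strings instead of being inferred as more specific types. A sketch:

```scala
// Disable partition column type inference for this context.
sqlContext.setConf("spark.sql.sources.partitionColumnTypeInference.enabled", "false")

// Restore the default behaviour.
sqlContext.setConf("spark.sql.sources.partitionColumnTypeInference.enabled", "true")
```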
@transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf) + protected[sql] lazy val functionRegistry: FunctionRegistry = { + val fr = new SimpleFunctionRegistry + FunctionRegistry.expressions.foreach { case (name, func) => fr.registerFunction(name, func) } + fr + } @transient protected[sql] lazy val analyzer: Analyzer = @@ -184,9 +186,28 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => setConf(key, value) - case _ => + { + // We extract spark sql settings from SparkContext's conf and put them to + // Spark SQL's conf. + // First, we populate the SQLConf (conf). So, we can make sure that other values using + // those settings in their construction can get the correct settings. + // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version + // and spark.sql.hive.metastore.jars to get correctly constructed. + val properties = new Properties + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) + case _ => + } + // We directly put those settings to conf to avoid calling setConf, which may have + // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive + // to get constructed. If we call setConf directly, the constructed metadataHive may have + // wrong settings, or the construction may fail. + conf.setConf(properties) + // After we have populated SQLConf, we call setConf to populate other confs in the subclass + // (e.g. hiveconf in HiveContext). + properties.foreach { + case (key, value) => setConf(key, value) + } } @transient @@ -300,7 +321,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args :_*)) + new ColumnName(sc.s(args : _*)) } } @@ -392,7 +413,7 @@ class SQLContext(@transient val sparkContext: SparkContext) SparkPlan.currentContext.set(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) } @@ -688,7 +709,18 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end`(exclusive) with step value 1. + * in a range from 0 to `end` (exclusive) with step value 1. + * + * @since 1.4.1 + * @group dataframe + */ + @Experimental + def range(end: Long): DataFrame = range(0, end) + + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in a range from `start` to `end` (exclusive) with step value 1. * * @since 1.4.0 * @group dataframe @@ -703,7 +735,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end`(exclusive) with an step value, with partition number + * in a range from `start` to `end` (exclusive) with a step value, with partition number * specified.
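With the new single-argument overload, range now comes in three shapes, each returning a DataFrame with a single LongType column named id. A quick sketch:

```scala
val a = sqlContext.range(1000)           // ids 0, 1, ..., 999
val b = sqlContext.range(10, 20)         // ids 10, 11, ..., 19
val c = sqlContext.range(0, 100, 10, 4)  // ids 0, 10, ..., 90 in 4 partitions
```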
* * @since 1.4.0 @@ -888,6 +920,11 @@ class SQLContext(@transient val sparkContext: SparkContext) tlSession.remove() } + protected[sql] def setSession(session: SQLSession): Unit = { + detachSession() + tlSession.set(session) + } + protected[sql] class SQLSession { // Note that this is a lazy val so we can override the default value in subclasses. protected[sql] lazy val conf: SQLConf = new SQLConf @@ -1011,7 +1048,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) + val (dataType, _) = JavaTypeInference.inferDataType(beanClass) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() } @@ -1023,21 +1060,33 @@ class SQLContext(@transient val sparkContext: SparkContext) //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) @@ -1048,6 +1097,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * [[DataFrame]] if no paths are passed in. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().parquet()`. */ @deprecated("Use read.parquet()", "1.4.0") @scala.annotation.varargs @@ -1067,6 +1117,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String): DataFrame = { @@ -1078,6 +1129,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { @@ -1086,6 +1138,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { @@ -1098,6 +1151,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. 
*/ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) @@ -1108,6 +1162,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) @@ -1117,6 +1172,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { @@ -1128,6 +1184,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { @@ -1139,6 +1196,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { @@ -1150,6 +1208,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { @@ -1161,6 +1220,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * using the default data source configured by spark.sql.sources.default. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().load(path)`. */ @deprecated("Use read.load(path)", "1.4.0") def load(path: String): DataFrame = { @@ -1171,6 +1231,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns the dataset stored at path as a DataFrame, using the given data source. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. */ @deprecated("Use read.format(source).load(path)", "1.4.0") def load(path: String, source: String): DataFrame = { @@ -1182,6 +1243,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { @@ -1193,6 +1255,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { @@ -1204,6 +1267,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. 
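The jsonFile/jsonRDD/load deprecations in this stretch all funnel into the DataFrameReader obtained from read. A before/after sketch (paths and eventSchema are placeholders):

```scala
// val people = sqlContext.jsonFile("/data/people.json")    becomes:
val people = sqlContext.read.json("/data/people.json")

// val events = sqlContext.load("/data/events", "parquet")  becomes:
val events = sqlContext.read.format("parquet").load("/data/events")

// With an explicit schema and extra options:
val typed = sqlContext.read
  .format("json")
  .schema(eventSchema)
  .option("samplingRatio", "1.0")
  .load("/data/events.json")
```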
*/ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = @@ -1216,6 +1281,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. */ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { @@ -1227,6 +1294,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * url named table. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc(url: String, table: String): DataFrame = { @@ -1244,6 +1312,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc( @@ -1263,6 +1332,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * of the [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index 6b1ae81972e4e..305b306a79871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -54,15 +54,15 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr } } - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val CLEAR = Keyword("CLEAR") - protected val IN = Keyword("IN") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SHOW = Keyword("SHOW") - protected val TABLE = Keyword("TABLE") - protected val TABLES = Keyword("TABLES") + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val CLEAR = Keyword("CLEAR") + protected val IN = Keyword("IN") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val SHOW = Keyword("SHOW") + protected val TABLE = Keyword("TABLE") + protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 423ecdff5804a..43b62f0e822f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -106,7 +106,7 @@ private[r] object SQLUtils { dfCols.map { col => colToRBytes(col) - } + } } def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { @@ -121,7 +121,7 @@ private[r] object SQLUtils { val numRows = col.length val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - + 
SerDe.writeInt(dos, numRows) col.map { item => @@ -139,4 +139,19 @@ private[r] object SQLUtils { case "ignore" => SaveMode.Ignore } } + + def loadDF( + sqlContext: SQLContext, + source: String, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).options(options).load() + } + + def loadDF( + sqlContext: SQLContext, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).schema(schema).options(options).load() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index a59d42cdd6028..3db26fad2b92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -236,7 +236,7 @@ private[sql] case class InMemoryColumnarTableScan( case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l - case IsNull(a: Attribute) => statsFor(a).nullCount > 0 + case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3e46596ecf6ac..f25d10fec0411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -296,7 +296,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ .sliding(2) .map { case Seq(a) => true - case Seq(a,b) => a compatibleWith b + case Seq(a, b) => a.compatibleWith(b) }.exists(!_) // Adds Exchange or Sort operators as required diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index a500269f3cdcf..f931dc95ef575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -21,9 +21,9 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} /** @@ -31,26 +31,19 @@ import org.apache.spark.sql.{Row, SQLContext} */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) - val schemaFields = schema.fields.toArray - val converters = 
schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r.productElement(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r.productElement(i)) + i += 1 } + + mutableRow } } } @@ -58,26 +51,19 @@ object RDDConversions { /** * Convert the objects inside Row into the types Catalyst expected. */ - def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r(i)) + i += 1 } + + mutableRow } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 2ec7d4fbc92de..3e27c1bde2dfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -138,15 +138,15 @@ case class GeneratedAggregate( case UnscaledValue(e) => e case _ => expr } - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present val updateFunction = If( IsNotNull(actualExpr), Coalesce( Add( - Coalesce(currentSum :: zero :: Nil), + Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: zero :: Nil), currentSum) - + val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -155,7 +155,7 @@ case class GeneratedAggregate( } AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3f6a0345bc17d..7a1331a39151a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -243,8 +243,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case (predicate, None) => predicate // Filter needs to be applied above when it contains partitioning // columns - case (predicate, _) if(!predicate.references.map(_.name).toSet - .intersect (partitionColNames).isEmpty) => 
predicate + case (predicate, _) + if !predicate.references.map(_.name).toSet.intersect(partitionColNames).isEmpty => + predicate } } } else { @@ -270,7 +271,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, identity[Seq[Expression]], // All filters still need to be evaluated. - InMemoryColumnarTableScan(_, filters, mem)) :: Nil + InMemoryColumnarTableScan(_, filters, mem)) :: Nil case _ => Nil } } @@ -283,8 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.Distinct(child) => - execution.Distinct(partial = false, - execution.Distinct(partial = true, planLater(child))) :: Nil + throw new IllegalStateException( + "logical distinct operator should have been replaced by aggregate in the optimizer") case logical.Repartition(numPartitions, shuffle, child) => execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6cb67b4bbbb65..fb42072f9d5a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -65,7 +65,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * :: DeveloperApi :: * Sample the dataset. * @param lowerBound Lower-bound of the sampling probability (usually 0.0) - * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed @@ -230,37 +230,6 @@ case class ExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } -/** - * :: DeveloperApi :: - * Computes the set of distinct input rows using a HashSet. - * @param partial when true the distinct operation is performed partially, per partition, without - * shuffling the data. - * @param child the input query plan. - */ -@DeveloperApi -case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - - override def requiredChildDistribution: Seq[Distribution] = - if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil - - protected override def doExecute(): RDD[Row] = { - child.execute().mapPartitions { iter => - val hashSet = new scala.collection.mutable.HashSet[Row]() - - var currentRow: Row = null - while (iter.hasNext) { - currentRow = iter.next() - if (!hashSet.contains(currentRow)) { - hashSet.add(currentRow.copy()) - } - } - - hashSet.iterator - } - } -} - /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. 
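With the physical Distinct operator removed above, DISTINCT is expected to be planned as an aggregate by the optimizer; a quick, hedged way to see this from the API (the exact plan text varies between builds):
{{{
import sqlContext.implicits._

val df = Seq(1, 1, 2, 2, 3).toDF("id")
df.distinct().explain()
// the printed physical plan should contain aggregate operators instead of the old Distinct node
}}}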
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 9ac732b55b188..e228a60c9029f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -39,8 +39,6 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L - override type EvaluatedType = Long - override def nullable: Boolean = false override def dataType: DataType = LongType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index c2c6cbd491598..1272793f88cd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types.{IntegerType, DataType} */ private[sql] case object SparkPartitionID extends LeafExpression { - override type EvaluatedType = Int - override def nullable: Boolean = false override def dataType: DataType = IntegerType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 640fc26ba3baa..a32e5fc4f7ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -39,7 +39,7 @@ case class BroadcastLeftSemiJoinHash( override def output: Seq[Attribute] = left.output protected override def doExecute(): RDD[Row] = { - val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator + val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 45574392996ca..c21a453115292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -48,7 +48,8 @@ case class HashOuterJoin( case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } override def requiredChildDistribution: Seq[ClusteredDistribution] = @@ -63,7 +64,7 @@ case class HashOuterJoin( case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } @@ -216,7 +217,8 @@ case class HashOuterJoin( rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) } - case x => throw new 
Exception(s"HashOuterJoin should not take $x as the JoinType") + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 11b2897f76786..55f3ff4709013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -56,7 +56,7 @@ private[spark] case class PythonUDF( def nullable: Boolean = true - override def eval(input: Row): PythonUDF.this.EvaluatedType = { + override def eval(input: Row): Any = { sys.error("PythonUDFs can not be directly evaluated.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 5ae7e107544f8..c41c21c0eeb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -62,7 +62,7 @@ private[sql] object FrequentItems extends Logging { } /** - * Finding frequent items for columns, possibly with false positives. Using the + * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. @@ -75,7 +75,7 @@ private[sql] object FrequentItems extends Logging { * @return A Local DataFrame with the Array of frequent items for each column. */ private[sql] def singlePassFreqItems( - df: DataFrame, + df: DataFrame, cols: Seq[String], support: Double): DataFrame = { require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") @@ -88,8 +88,8 @@ private[sql] object FrequentItems extends Logging { val index = originalSchema.fieldIndex(name) (name, originalSchema.fields(index).dataType) } - - val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)( + + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { @@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging { } ) val justItems = freqItems.map(m => m.baseMap.keys.toSeq) - val resultRow = Row(justItems:_*) + val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d22f5fd2d439c..93383e5a62f11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ private[sql] object StatFunctions extends Logging { - + /** 
Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols) @@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging { s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => val countsRow = new GenericMutableRow(columnSize + 1) - rows.foreach { row => + rows.foreach { (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(3) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts @@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging { val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)) + new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index d4003b2d9cbf6..e9b60841fc28c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -79,3 +79,20 @@ object Window { } } + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9a23cfb89ca12..454af47913bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,7 +24,6 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -187,7 +186,7 @@ object functions { */ @scala.annotation.varargs def countDistinct(columnName: String, columnNames: String*): Column = - countDistinct(Column(columnName), columnNames.map(Column.apply) :_*) + countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) /** * Aggregate function: returns the approximate number of distinct items in a group. 
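The Window helper documented above is meant to be combined with Column.over; an illustrative end-to-end use (the `sales` DataFrame and its country/date/amount columns are hypothetical):
{{{
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// running total of `amount` per country, ordered by date
val w = Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0)
val withRunningTotal = sales.withColumn("running_total", sum("amount").over(w))
}}}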
@@ -1299,7 +1298,7 @@ object functions { * @since 1.4.0 */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) - + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index be03a237b6c4e..db68b9c86db1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -46,10 +46,15 @@ private[sql] object JDBCRDD extends Logging { * @param sqlType - A field of java.sql.Types * @return The Catalyst type corresponding to sqlType. */ - private def getCatalystType(sqlType: Int, precision: Int, scale: Int): DataType = { + private def getCatalystType( + sqlType: Int, + precision: Int, + scale: Int, + signed: Boolean): DataType = { val answer = sqlType match { + // scalastyle:off case java.sql.Types.ARRAY => null - case java.sql.Types.BIGINT => LongType + case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType.Unlimited } case java.sql.Types.BINARY => BinaryType case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks case java.sql.Types.BLOB => BinaryType @@ -64,7 +69,7 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType case java.sql.Types.FLOAT => FloatType - case java.sql.Types.INTEGER => IntegerType + case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType } case java.sql.Types.JAVA_OBJECT => null case java.sql.Types.LONGNVARCHAR => StringType case java.sql.Types.LONGVARBINARY => BinaryType @@ -88,7 +93,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType - case _ => null + case _ => null + // scalastyle:on } if (answer == null) throw new SQLException("Unsupported type " + sqlType) @@ -123,11 +129,12 @@ private[sql] object JDBCRDD extends Logging { val typeName = rsmd.getColumnTypeName(i + 1) val fieldSize = rsmd.getPrecision(i + 1) val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( - getCatalystType(dataType, fieldSize, fieldScale)) + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } @@ -204,12 +211,14 @@ private[sql] object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition]): RDD[Row] = { + val dialect = JdbcDialects.get(url) + val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, getConnector(driver, url, properties), pruneSchema(schema, requiredColumns), fqTable, - requiredColumns, + quotedColumns, filters, parts, properties) @@ -255,7 +264,7 @@ private[sql] class JDBCRDD( } private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") + if (value == null) null else StringUtils.replace(value, "'", "''") /** * Turns a single Filter into a String representing a 
SQL expression. @@ -297,7 +306,7 @@ private[sql] class JDBCRDD( // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that // we don't have to potentially poke around in the Metadata once for every - // row. + // row. // Is there a better way to do this? I'd rather be using a type that // contains only the tags I define. abstract class JDBCConversion @@ -318,19 +327,19 @@ private[sql] class JDBCRDD( */ def getConversions(schema: StructType): Array[JDBCConversion] = { schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion + case BooleanType => BooleanConversion + case DateType => DateConversion case DecimalType.Unlimited => DecimalConversion(None) - case DecimalType.Fixed(d) => DecimalConversion(Some(d)) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => + case DecimalType.Fixed(d) => DecimalConversion(Some(d)) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case _ => throw new IllegalArgumentException(s"Unsupported field $sf") }).toArray } @@ -371,8 +380,8 @@ private[sql] class JDBCRDD( while (i < conversions.length) { val pos = i + 1 conversions(i) match { - case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - case DateConversion => + case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + case DateConversion => // DateUtils.fromJavaDate does not handle null value, so we need to check it. 
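With the new `signed` flag above, unsigned integral JDBC columns are widened instead of being silently narrowed; an illustrative read (the URL, table and column names are made up, and the exact printed precision depends on the driver):
{{{
val props = new java.util.Properties()
val orders = sqlContext.read.jdbc("jdbc:mysql://db-host:3306/shop", "orders", props)
orders.printSchema()
// an INT UNSIGNED column should now surface as a bigint (LongType),
// and a BIGINT UNSIGNED column as an unlimited decimal rather than a LongType that could overflow
}}}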
val dateVal = rs.getDate(pos) if (dateVal != null) { @@ -402,14 +411,14 @@ private[sql] class JDBCRDD( } else { mutableRow.update(i, Decimal(decimalVal)) } - case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) - case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) - case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) - case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) + case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) + case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) + case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) - case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) + case StringConversion => mutableRow.setString(i, rs.getString(pos)) + case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) + case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) case BinaryLongConversion => { val bytes = rs.getBytes(pos) var ans = 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 09d6865457df6..30f9190d45bf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -54,7 +54,7 @@ private[sql] object JDBCRelation { if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions + val stride: Long = (partitioning.upperBound / numPartitions - partitioning.lowerBound / numPartitions) var i: Int = 0 var currentValue: Long = partitioning.lowerBound @@ -140,10 +140,10 @@ private[sql] case class JDBCRelation( filters, parts) } - + override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) - } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 6a169e106b968..8849fc2f1f0ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.jdbc +import java.sql.Types + import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi -import java.sql.Types - /** * :: DeveloperApi :: * A database type definition coupled with the jdbc type needed to send null @@ -80,6 +80,14 @@ abstract class JdbcDialect { * @return The new JdbcType if there is an override for this DataType */ def getJDBCType(dt: DataType): Option[JdbcType] = None + + /** + * Quotes the identifier. This is used to put quotes around the identifier in case the column + * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). 
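As a usage sketch of the quoting hook being introduced here, a third-party dialect can override quoteIdentifier and register itself (the dialect name, URL prefix and bracket quoting below are hypothetical, and JdbcDialects.registerDialect is assumed to be available):
{{{
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// hypothetical dialect for a database that quotes identifiers with square brackets
case object BracketQuotingDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
  override def quoteIdentifier(colName: String): String = s"[$colName]"
}

JdbcDialects.registerDialect(BracketQuotingDialect)
}}}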
+ */ + def quoteIdentifier(colName: String): String = { + s""""$colName"""" + } } /** @@ -141,18 +149,19 @@ object JdbcDialects { @DeveloperApi class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { - require(!dialects.isEmpty) + require(dialects.nonEmpty) - def canHandle(url : String): Boolean = + override def canHandle(url : String): Boolean = dialects.map(_.canHandle(url)).reduce(_ && _) override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = - dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption - - override def getJDBCType(dt: DataType): Option[JdbcType] = - dialects.map(_.getJDBCType(dt)).flatten.headOption + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } } /** @@ -161,7 +170,7 @@ class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { */ @DeveloperApi case object NoopDialect extends JdbcDialect { - def canHandle(url : String): Boolean = true + override def canHandle(url : String): Boolean = true } /** @@ -170,7 +179,7 @@ case object NoopDialect extends JdbcDialect { */ @DeveloperApi case object PostgresDialect extends JdbcDialect { - def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { @@ -196,7 +205,7 @@ case object PostgresDialect extends JdbcDialect { */ @DeveloperApi case object MySQLDialect extends JdbcDialect { - def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { @@ -208,4 +217,8 @@ case object MySQLDialect extends JdbcDialect { Some(BooleanType) } else None } + + override def quoteIdentifier(colName: String): String = { + s"`$colName`" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index f21dd29aca37f..dd8aaf6474895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -240,10 +240,10 @@ package object jdbc { } } } - + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 9c58b8e4bb16a..565d10247f10e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -124,7 +124,7 @@ private[sql] object InferSchema { case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case ArrayType(struct: StructType, 
containsNull) => ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) + case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } @@ -147,7 +147,7 @@ private[sql] object InferSchema { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { case (other: DataType, NullType) => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala index 80bf74aa02602..325f54b6808a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala @@ -33,7 +33,7 @@ private[sql] object JacksonGenerator { */ def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = { def valWriter: (DataType, Any) => Unit = { - case (_, null) | (NullType, _) => gen.writeNull() + case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v: String) => gen.writeString(v) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) @@ -48,16 +48,16 @@ private[sql] object JacksonGenerator { case (DateType, v) => gen.writeString(v.toString) case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) - case (ArrayType(ty, _), v: Seq[_] ) => + case (ArrayType(ty, _), v: Seq[_]) => gen.writeStartArray() - v.foreach(valWriter(ty,_)) + v.foreach(valWriter(ty, _)) gen.writeEndArray() - case (MapType(kv,vv, _), v: Map[_,_]) => + case (MapType(kv, vv, _), v: Map[_, _]) => gen.writeStartObject() v.foreach { p => gen.writeFieldName(p._1.toString) - valWriter(vv,p._2) + valWriter(vv, p._2) } gen.writeEndObject() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 037a6d60a2ed6..7e1e21f5fbb99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -141,7 +141,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case ArrayType(struct: StructType, containsNull) => ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) + case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } StructField(fieldName, newType, nullable) @@ -155,7 +155,7 @@ private[sql] object JsonRDD extends Logging { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2) match { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { case Some(commonType) => commonType case None => // t1 or t2 is a StructType, ArrayType, or an unexpected type. @@ -216,7 +216,7 @@ private[sql] object JsonRDD extends Logging { case map: Map[_, _] => StructType(Nil) // We have an array of arrays. If those element arrays do not have the same // element types, we will return ArrayType[StringType]. 
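The compatibleType/findTightestCommonTypeOfTwo path above is what reconciles conflicting field types during JSON schema inference; a small hedged illustration (`sc` is the SparkContext):
{{{
val records = sc.parallelize(Seq("""{"a": 1}""", """{"a": 2.5}"""))
val df = sqlContext.read.json(records)
df.printSchema()
// `a` should come back as a double: the integral and floating-point record types
// are merged to their tightest common type
}}}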
- case seq: Seq[_] => typeOfArray(seq) + case seq: Seq[_] => typeOfArray(seq) case value => typeOfPrimitiveValue(value) } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) @@ -406,7 +406,7 @@ private[sql] object JsonRDD extends Logging { } } - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = { if (value == null) { null } else { @@ -434,7 +434,7 @@ private[sql] object JsonRDD extends Logging { } } - private def asRow(json: Map[String,Any], schema: StructType): Row = { + private def asRow(json: Map[String, Any], schema: StructType): Row = { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 3f97a11ceb97d..4e94fd07a8771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -44,6 +44,7 @@ package object sql { /** * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + * @deprecated As of 1.3.0, replaced by `DataFrame`. */ @deprecated("1.3.0", "use DataFrame") type SchemaRDD = DataFrame diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index f5ce2718bec4a..62c4e92ebec68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -21,9 +21,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import parquet.Log -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.Log +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 36cb5e03bbca7..85c2ce740fe52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -23,9 +23,9 @@ import java.util.{TimeZone, Calendar} import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import jodd.datetime.JDateTime -import parquet.column.Dictionary -import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} -import parquet.schema.MessageType +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType @@ -243,8 +243,10 @@ private[parquet] abstract class 
CatalystConverter extends GroupConverter { /** * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in * a long (i.e. precision <= 18) + * + * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { val precision = ctype.precisionInfo.get.precision val scale = ctype.precisionInfo.get.scale val bytes = value.getBytes @@ -480,7 +482,7 @@ private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverte override def hasDictionarySupport: Boolean = true - override def setDictionary(dictionary: Dictionary):Unit = + override def setDictionary(dictionary: Dictionary): Unit = dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = @@ -591,8 +593,8 @@ private[parquet] class CatalystArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter @@ -601,7 +603,7 @@ private[parquet] class CatalystArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { // fieldIndex is ignored (assumed to be zero but not checked) - if(value == null) { + if (value == null) { throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") } buffer += value @@ -654,8 +656,8 @@ private[parquet] class CatalystNativeArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f0f4e7d147e75..88ae88e9684c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -21,11 +21,11 @@ import java.nio.ByteBuffer import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration -import parquet.filter2.compat.FilterCompat -import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterApi._ -import parquet.filter2.predicate.{FilterApi, FilterPredicate} -import parquet.io.api.Binary +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.compat.FilterCompat._ +import org.apache.parquet.filter2.predicate.FilterApi._ +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index fcb9513ab66f6..704cf56f38265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -18,20 +18,21 @@ package org.apache.spark.sql.parquet import java.io.IOException -import java.util.logging.Level +import java.util.logging.{Level, Logger => JLogger} import 
org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import org.apache.spark.sql.types.{StructType, DataType} -import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} -import parquet.hadoop.metadata.CompressionCodecName -import parquet.schema.MessageType +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat, ParquetRecordReader} +import org.apache.parquet.schema.MessageType +import org.apache.parquet.{Log => ParquetLog} -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext} /** * Relation that consists of data stored in a Parquet columnar format. @@ -94,40 +95,44 @@ private[sql] case class ParquetRelation( private[sql] object ParquetRelation { def enableLogForwarding() { - // Note: the parquet.Log class has a static initializer that - // sets the java.util.logging Logger for "parquet". This + // Note: the org.apache.parquet.Log class has a static initializer that + // sets the java.util.logging Logger for "org.apache.parquet". This // checks first to see if there's any handlers already set // and if not it creates them. If this method executes prior // to that class being loaded then: // 1) there's no handlers installed so there's none to // remove. But when it IS finally loaded the desired affect // of removing them is circumvented. - // 2) The parquet.Log static initializer calls setUseParentHanders(false) + // 2) The parquet.Log static initializer calls setUseParentHandlers(false) // undoing the attempt to override the logging here. // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName) + Class.forName(classOf[ParquetLog].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. - val parquetLogger = java.util.logging.Logger.getLogger("parquet") + val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName) parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) - // TODO(witgo): Need to set the log level ? - // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) - if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) + parquetLogger.setUseParentHandlers(true) - // Disables WARN log message in ParquetOutputCommitter. + // Disables a WARN log message in ParquetOutputCommitter. We first ensure that + // ParquetOutputCommitter is loaded and the static LOG field gets initialized. // See https://issues.apache.org/jira/browse/SPARK-5968 for details Class.forName(classOf[ParquetOutputCommitter].getName) - java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) + JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) + + // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. 
+ // See https://issues.apache.org/jira/browse/PARQUET-220 for details + Class.forName(classOf[ParquetRecordReader[_]].getName) + JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) } // The element type for the RDDs that this relation maps to. type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow // The compression type - type CompressionType = parquet.hadoop.metadata.CompressionCodecName + type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName // The parquet compression short names val shortParquetCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 90950f924a054..1e694f2feabee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -33,13 +33,13 @@ import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import parquet.hadoop._ -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{InitContext, ReadSupport} -import parquet.hadoop.metadata.GlobalMetaData -import parquet.hadoop.util.ContextUtil -import parquet.io.ParquetDecodingException -import parquet.schema.MessageType +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.hadoop.metadata.GlobalMetaData +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.io.ParquetDecodingException +import org.apache.parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil @@ -78,7 +78,7 @@ private[sql] case class ParquetTableScan( }.toArray protected override def doExecute(): RDD[Row] = { - import parquet.filter2.compat.FilterCompat.FilterPredicateCompat + import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -136,7 +136,7 @@ private[sql] case class ParquetTableScan( baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] .getPath .toString .split("/") @@ -378,7 +378,7 @@ private[sql] case class InsertIntoParquetTable( * to imported ones. */ private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends parquet.hadoop.ParquetOutputFormat[Row] { + extends org.apache.parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} var committer: OutputCommitter = null @@ -431,7 +431,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) * RecordFilter we want to use. 
*/ private[parquet] class FilteringParquetRowInputFormat - extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + extends org.apache.parquet.hadoop.ParquetInputFormat[Row] with Logging { private var fileStatuses = Map.empty[Path, FileStatus] @@ -439,7 +439,7 @@ private[parquet] class FilteringParquetRowInputFormat inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { - import parquet.filter2.compat.FilterCompat.NoOpFilter + import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter val readSupport: ReadSupport[Row] = new RowReadSupport() @@ -501,7 +501,7 @@ private[parquet] class FilteringParquetRowInputFormat globalMetaData = new GlobalMetaData(globalMetaData.getSchema, mergedMetadata, globalMetaData.getCreatedBy) - val readContext = getReadSupport(configuration).init( + val readContext = ParquetInputFormat.getReadSupportInstance(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData, globalMetaData.getSchema)) @@ -531,8 +531,8 @@ private[parquet] class FilteringParquetRowInputFormat minSplitSize: JLong, readContext: ReadContext): JList[ParquetInputSplit] = { - import parquet.filter2.compat.FilterCompat.Filter - import parquet.filter2.compat.RowGroupFilter + import org.apache.parquet.filter2.compat.FilterCompat.Filter + import org.apache.parquet.filter2.compat.RowGroupFilter import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache @@ -541,13 +541,13 @@ private[parquet] class FilteringParquetRowInputFormat val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val filter: Filter = ParquetInputFormat.getFilter(configuration) var rowGroupsDropped: Long = 0 - var totalRowGroups: Long = 0 + var totalRowGroups: Long = 0 // Ugly hack, stuck with it until PR: // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") + Class.forName("org.apache.parquet.hadoop.ClientSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) generateSplits.setAccessible(true) @@ -612,7 +612,7 @@ private[parquet] class FilteringParquetRowInputFormat // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") + Class.forName("org.apache.parquet.hadoop.TaskSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( sys.error( s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) @@ -664,7 +664,7 @@ private[parquet] object FileSystemHelper { s"ParquetTableOperations: path $path does not exist or is not a directory") } fs.globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .map(_.getPath) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 70a220cc43ab9..89db408b1c382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => 
JHashMap} import org.apache.hadoop.conf.Configuration -import parquet.column.ParquetProperties -import parquet.hadoop.ParquetOutputFormat -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{ReadSupport, WriteSupport} -import parquet.io.api._ -import parquet.schema.MessageType +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport} +import org.apache.parquet.io.api._ +import org.apache.parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 1dc819b5d7b9b..ba2a35b74ef82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,26 +19,25 @@ package org.apache.spark.sql.parquet import java.io.IOException -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job -import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition -import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ -import org.apache.spark.{Logging, SparkException} -// Implicits -import scala.collection.JavaConversions._ /** A class representing Parquet info fields we care about, for passing back to Parquet */ private[parquet] case class ParquetTypeInfo( @@ -73,13 +72,12 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
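Because Parquet classes now live under org.apache.parquet, user code that imported them directly needs the same one-line adjustment made throughout these files (imports shown for illustration only):
{{{
// before (pre-1.7 parquet-mr coordinates):
//   import parquet.hadoop.ParquetOutputFormat
//   import parquet.io.api.Binary

// after:
import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.parquet.io.api.Binary
}}}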
- sys.error("Potential loss of precision: cannot convert INT96") + throw new AnalysisException("Potential loss of precision: cannot convert INT96") case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => // TODO: for now, our reader only supports decimals that fit in a Long DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") + case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType") } } @@ -371,7 +369,7 @@ private[parquet] object ParquetTypesConverter extends Logging { parquetKeyType, parquetValueType) } - case _ => sys.error(s"Unsupported datatype $ctype") + case _ => throw new AnalysisException(s"Unsupported datatype $ctype") } } } @@ -403,7 +401,7 @@ private[parquet] object ParquetTypesConverter extends Logging { def convertFromString(string: String): Seq[Attribute] = { Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes - case other => sys.error(s"Can convert $string to row") + case other => throw new AnalysisException(s"Can convert $string to row") } } @@ -411,8 +409,8 @@ private[parquet] object ParquetTypesConverter extends Logging { // ,;{}()\n\t= and space character are special characters in Parquet schema schema.map(_.name).foreach { name => if (name.matches(".*[ ,;{}()\n\t=].*")) { - sys.error( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". + throw new AnalysisException( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". |Please use alias to rename it. """.stripMargin.split("\n").mkString(" ")) } @@ -489,7 +487,7 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs .globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .filterNot { status => val name = status.getPath.getName (name(0) == '.' 
|| name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index cb1e60883df1e..7af4eb1ca4716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import java.net.URI import java.util.{List => JList} import scala.collection.JavaConversions._ @@ -28,16 +29,17 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import parquet.filter2.predicate.FilterApi -import parquet.hadoop._ -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.util.ContextUtil +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLConf, SQLContext} @@ -82,7 +84,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext case partFilePattern(id) => id.toInt case name if name.startsWith("_") => 0 case name if name.startsWith(".") => 0 - case name => sys.error( + case name => throw new AnalysisException( s"Trying to write Parquet files to directory $outputPath, " + s"but found items with illegal name '$name'.") }.reduceOption(_ max _).getOrElse(0) @@ -154,7 +156,7 @@ private[sql] class ParquetRelation2( meta } - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case that: ParquetRelation2 => val schemaEquality = if (shouldMergeSchemas) { this.shouldMergeSchemas == that.shouldMergeSchemas @@ -189,7 +191,7 @@ private[sql] class ParquetRelation2( } } - override def dataSchema: StructType = metadataCache.dataSchema + override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) override private[sql] def refresh(): Unit = { super.refresh() @@ -210,6 +212,13 @@ private[sql] class ParquetRelation2( classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + if (conf.get("spark.sql.parquet.output.committer.class") == null) { + logInfo("Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + conf.setClass( SQLConf.OUTPUT_COMMITTER_CLASS, committerClass, @@ -282,21 +291,28 @@ private[sql] class ParquetRelation2( val cacheMetadata = useMetadataCache @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as /, - // we need to use the string returned by the URI of the path to create a new Path. 
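The spark.sql.parquet.output.committer.class key logged above can be pointed at the direct committer touched earlier in this patch; a hedged sketch of one way to set it (via the Hadoop configuration, which the write path picks up):
{{{
sqlContext.sparkContext.hadoopConfiguration.set(
  "spark.sql.parquet.output.committer.class",
  "org.apache.spark.sql.parquet.DirectParquetOutputCommitter")
}}}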
- val pathWithAuthority = new Path(f.getPath.toUri.toString) - + // In order to encode the authority of a Path containing special characters such as '/' + // (which does happen in some S3N credentials), we need to use the string returned by the + // URI of the path to create a new Path. + val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) new FileStatus( f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithAuthority) + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) }.toSeq @transient val cachedFooters = footers.map { f => // In order to encode the authority of a Path containing special characters such as /, // we need to use the string returned by the URI of the path to create a new Path. - new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) + new Footer(escapePathUserInfo(f.getFile), f.getParquetMetadata) }.toSeq + private def escapePathUserInfo(path: Path): Path = { + val uri = path.toUri + new Path(new URI( + uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, + uri.getQuery, uri.getFragment)) + } + // Overridden so we can inject our own cached files statuses. override def getPartitions: Array[SparkPartition] = { val inputFormat = if (cacheMetadata) { @@ -372,12 +388,13 @@ private[sql] class ParquetRelation2( // time-consuming. if (dataSchema == null) { dataSchema = { - val dataSchema0 = - maybeDataSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(sys.error("Failed to get the schema.")) - + val dataSchema0 = maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(throw new AnalysisException( + s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + + paths.mkString("\n\t"))) + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case // case insensitivity issue and possible schema mismatch (probably caused by schema // evolution). 
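Illustrative sketch (not part of the patch): the escapePathUserInfo helper added above rebuilds each cached Path from the raw components of its URI, so that percent-escaped user info (e.g. S3N credentials containing '/') is not decoded along the way. A minimal, Hadoop-free example of the distinction it relies on, using only java.net.URI and a made-up credential:

import java.net.URI

object RawUserInfoSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical S3N-style location whose secret key contains an escaped '/' (%2F).
    val uri = new URI("s3n://ACCESS:SECRET%2FWITH%2FSLASH@bucket/table/part-00000.parquet")

    // getUserInfo decodes the escapes, so the slashes reappear and would corrupt the
    // authority if this decoded form were ever turned back into a path string.
    println(uri.getUserInfo)    // ACCESS:SECRET/WITH/SLASH

    // getRawUserInfo keeps the escaping intact; this is the accessor the patch feeds back
    // into the multi-argument URI constructor when rebuilding the Path.
    println(uri.getRawUserInfo) // ACCESS:SECRET%2FWITH%2FSLASH
  }
}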
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala index 70bcca7526aae..4d5ed211ad0c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.parquet.timestamp import java.nio.{ByteBuffer, ByteOrder} -import parquet.Preconditions -import parquet.io.api.{Binary, RecordConsumer} +import org.apache.parquet.Preconditions +import org.apache.parquet.io.api.{Binary, RecordConsumer} private[parquet] class NanoTime extends Serializable { private var julianDay = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index dacd967cff856..c6a4dabbab05e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -309,7 +309,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { output: Seq[Attribute], rdd: RDD[Row]): SparkPlan = { val converted = if (relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index e0ead23d786f9..7a2b5b949dd4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.sources -import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.lang.{Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.Shell import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -71,10 +72,11 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartitions( paths: Seq[Path], - defaultPartitionName: String): PartitionSpec = { + defaultPartitionName: String, + typeInference: Boolean): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. 
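The hunk below threads a new typeInference flag through parsePartitions, parsePartition, and inferPartitionColumnValue. A rough, dependency-free sketch of the idea (illustrative names, simplified type ladder with no DecimalType or default-partition handling; not the Spark API):

import scala.util.Try

object PartitionParsingSketch {
  // Simplified inference ladder: Int, then Long, then Double, then plain String.
  def inferValue(raw: String, typeInference: Boolean): Any =
    if (typeInference) {
      Try[Any](raw.toInt)
        .orElse(Try(raw.toLong))
        .orElse(Try(raw.toDouble))
        .getOrElse(raw)
    } else {
      raw
    }

  // Split a "col1=v1/col2=v2/..." partition path into (columnName, inferredValue) pairs.
  def parsePartitionPath(path: String, typeInference: Boolean): List[(String, Any)] =
    path.split("/").toList.collect {
      case segment if segment.contains('=') =>
        val Array(name, raw) = segment.split("=", 2)
        name -> inferValue(raw, typeInference)
    }

  def main(args: Array[String]): Unit = {
    println(parsePartitionPath("year=2015/month=05/name=spark", typeInference = true))
    // List((year,2015), (month,5), (name,spark))
    println(parsePartitionPath("year=2015/month=05/name=spark", typeInference = false))
    // List((year,2015), (month,05), (name,spark))
  }
}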
val pathsWithPartitionValues = paths.flatMap { path => - parsePartition(path, defaultPartitionName).map(path -> _) + parsePartition(path, defaultPartitionName, typeInference).map(path -> _) } if (pathsWithPartitionValues.isEmpty) { @@ -123,7 +125,8 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartition( path: Path, - defaultPartitionName: String): Option[PartitionValues] = { + defaultPartitionName: String, + typeInference: Boolean): Option[PartitionValues] = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null @@ -136,7 +139,7 @@ private[sql] object PartitioningUtils { return None } - val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) maybeColumn.foreach(columns += _) chopped = chopped.getParent finished = maybeColumn.isEmpty || chopped.getParent == null @@ -152,7 +155,8 @@ private[sql] object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - defaultPartitionName: String): Option[(String, Literal)] = { + defaultPartitionName: String, + typeInference: Boolean): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { None @@ -163,7 +167,7 @@ private[sql] object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference) Some(columnName -> literal) } } @@ -174,7 +178,7 @@ private[sql] object PartitioningUtils { * {{{ * NullType -> * IntegerType -> LongType -> - * FloatType -> DoubleType -> DecimalType.Unlimited -> + * DoubleType -> DecimalType.Unlimited -> * StringType * }}} */ @@ -186,7 +190,7 @@ private[sql] object PartitioningUtils { Seq.empty } else { assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n", "") + val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") s"Conflicting partition column names detected:\n$list" }) @@ -204,25 +208,36 @@ private[sql] object PartitioningUtils { } /** - * Converts a string to a `Literal` with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * Converts a string to a [[Literal]] with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and * [[StringType]]. 
*/ private[sql] def inferPartitionColumnValue( raw: String, - defaultPartitionName: String): Literal = { - // First tries integral types - Try(Literal.create(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) - // Then falls back to fractional types - .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) - // Then falls back to string - .getOrElse { - if (raw == defaultPartitionName) Literal.create(null, NullType) - else Literal.create(raw, StringType) + defaultPartitionName: String, + typeInference: Boolean): Literal = { + if (typeInference) { + // First tries integral types + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) + } + } + } else { + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) } + } } private val upCastingOrder: Seq[DataType] = @@ -243,4 +258,77 @@ private[sql] object PartitioningUtils { Literal.create(Cast(l, desiredType).eval(), desiredType) } } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). + ////////////////////////////////////////////////////////////////////////////////////////////////// + + val charToEscape = { + val bitSet = new java.util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. 
+ */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c) + } + + def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02x") + } else { + builder.append(c) + } + } + + builder.toString() + } + + def unescapePathName(path: String): String = { + val sb = new StringBuilder + var i = 0 + + while (i < path.length) { + val c = path.charAt(i) + if (c == '%' && i + 2 < path.length) { + val code: Int = try { + Integer.valueOf(path.substring(i + 1, i + 3), 16) + } catch { case e: Exception => + -1: Integer + } + if (code >= 0) { + sb.append(code.asInstanceOf[Char]) + i += 3 + } else { + sb.append(c) + i += 1 + } + } else { + sb.append(c) + i += 1 + } + } + + sb.toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala index a74a98631da35..ebad0c1564ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -216,7 +216,7 @@ private[sql] class SqlNewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index fbd98ef0380e1..c94199bfcd233 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -24,18 +24,19 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.hadoop.util.Shell -import parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.util.ContextUtil import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.{SQLConf, 
DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -94,10 +95,19 @@ private[sql] case class InsertIntoHadoopFsRelation( // We create a DataFrame by applying the schema of relation to the data to make sure. // We are writing data based on the expected schema, - val df = sqlContext.createDataFrame( - DataFrame(sqlContext, query).queryExecution.toRdd, - relation.schema, - needsConversion = false) + val df = { + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). We + // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can + // safely apply the schema of r.schema to the data. + val project = Project( + relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) + + sqlContext.createDataFrame( + DataFrame(sqlContext, project).queryExecution.toRdd, + relation.schema, + needsConversion = false) + } val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { @@ -117,8 +127,11 @@ private[sql] case class InsertIntoHadoopFsRelation( val needsConversion = relation.needConversion val dataSchema = relation.dataSchema + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) writerContainer.commitJob() relation.refresh() @@ -129,9 +142,10 @@ private[sql] case class InsertIntoHadoopFsRelation( } def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) - + // If anything below fails, we should abort the task. try { + writerContainer.executorSideSetup(taskContext) + if (needsConversion) { val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) while (iterator.hasNext) { @@ -144,6 +158,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.outputWriterForRow(row).write(row) } } + writerContainer.commitTask() } catch { case cause: Throwable => logError("Aborting task.", cause) @@ -181,8 +196,11 @@ private[sql] case class InsertIntoHadoopFsRelation( val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) val codegenEnabled = df.sqlContext.conf.codegenEnabled + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
+ writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) writerContainer.commitJob() relation.refresh() @@ -193,30 +211,39 @@ private[sql] case class InsertIntoHadoopFsRelation( } def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) - - val partitionProj = newProjection(codegenEnabled, partitionOutput, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = partitionProj(row) - val dataPart = dataProj(row) - val convertedDataPart = converter(dataPart).asInstanceOf[Row] - writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) - } - } else { - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = partitionProj(row) - val dataPart = dataProj(row) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + // If anything below fails, we should abort the task. + try { + writerContainer.executorSideSetup(taskContext) + + val partitionProj = newProjection(codegenEnabled, partitionOutput, output) + val dataProj = newProjection(codegenEnabled, dataOutput, output) + + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + while (iterator.hasNext) { + val row = iterator.next() + val partitionPart = partitionProj(row) + val dataPart = dataProj(row) + val convertedDataPart = converter(dataPart).asInstanceOf[Row] + writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) + } + } else { + val partitionSchema = StructType.fromAttributes(partitionOutput) + val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) + while (iterator.hasNext) { + val row = iterator.next() + val partitionPart = converter(partitionProj(row)).asInstanceOf[Row] + val dataPart = dataProj(row) + writerContainer.outputWriterForRow(partitionPart).write(dataPart) + } } - } - writerContainer.commitTask() + writerContainer.commitTask() + } catch { case cause: Throwable => + logError("Aborting task.", cause) + writerContainer.abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } } } @@ -270,8 +297,17 @@ private[sql] abstract class BaseWriterContainer( def driverSideSetup(): Unit = { setupIDs(0, 0, 0) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor + // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, + // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. + // + // Also, the `prepareJobForWrite` call must happen before initializing output format and output + // committer, since their initialization involve the job configuration, which can be potentially + // decorated in `prepareJobForWrite`. 
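Side note on the ordering comment that closes above: on Hadoop 1 the TaskAttemptContext constructor clones the Configuration it is given, so anything prepareJobForWrite sets afterwards never reaches the context. A toy, self-contained sketch of that failure mode (ToyConfiguration and ToyTaskAttemptContext are stand-ins, not Hadoop classes):

import scala.collection.mutable

object PrepareJobOrderingSketch {
  class ToyConfiguration(initial: Map[String, String] = Map.empty) {
    private val values = mutable.Map(initial.toSeq: _*)
    def set(key: String, value: String): Unit = values(key) = value
    def get(key: String): Option[String] = values.get(key)
    def copy(): ToyConfiguration = new ToyConfiguration(values.toMap)
  }

  // Stands in for Hadoop 1's TaskAttemptContext, which copies the configuration at construction.
  class ToyTaskAttemptContext(conf: ToyConfiguration) {
    private val snapshot = conf.copy()
    def getConfiguration: ToyConfiguration = snapshot
  }

  def main(args: Array[String]): Unit = {
    val conf = new ToyConfiguration

    // Wrong order: build the context first, then let prepareJobForWrite decorate the conf.
    val staleContext = new ToyTaskAttemptContext(conf)
    conf.set("output.committer.class", "MyCommitter") // pretend prepareJobForWrite did this
    println(staleContext.getConfiguration.get("output.committer.class")) // None -- lost

    // Right order (what the patch enforces): decorate the conf first, then build the context.
    val freshContext = new ToyTaskAttemptContext(conf)
    println(freshContext.getConfiguration.get("output.committer.class")) // Some(MyCommitter)
  }
}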
outputWriterFactory = relation.prepareJobForWrite(job) + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupJob(jobContext) @@ -299,6 +335,8 @@ private[sql] abstract class BaseWriterContainer( SQLConf.OUTPUT_COMMITTER_CLASS, null, classOf[OutputCommitter]) Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat // has an associated output committer. To override this output committer, // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. @@ -318,7 +356,9 @@ private[sql] abstract class BaseWriterContainer( }.getOrElse { // If output committer class is not set, we will use the one associated with the // file output format. - outputFormatClass.newInstance().getOutputCommitter(context) + val outputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + logInfo(s"Using output committer class ${outputCommitter.getClass.getCanonicalName}") + outputCommitter } } @@ -347,7 +387,9 @@ private[sql] abstract class BaseWriterContainer( } def abortTask(): Unit = { - outputCommitter.abortTask(taskAttemptContext) + if (outputCommitter != null) { + outputCommitter.abortTask(taskAttemptContext) + } logError(s"Task attempt $taskAttemptId aborted.") } @@ -357,7 +399,9 @@ private[sql] abstract class BaseWriterContainer( } def abortJob(): Unit = { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + if (outputCommitter != null) { + outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + } logError(s"Job $jobId aborted.") } } @@ -378,6 +422,7 @@ private[sql] class DefaultWriterContainer( override def commitTask(): Unit = { try { + assert(writer != null, "OutputWriter instance should have been initialized") writer.close() super.commitTask() } catch { @@ -389,7 +434,9 @@ private[sql] class DefaultWriterContainer( override def abortTask(): Unit = { try { - writer.close() + if (writer != null) { + writer.close() + } } finally { super.abortTask() } @@ -416,7 +463,7 @@ private[sql] class DynamicPartitionWriterContainer( val valueString = if (string == null || string.isEmpty) { defaultPartitionName } else { - DynamicPartitionWriterContainer.escapePathName(string) + PartitioningUtils.escapePathName(string) } s"/$col=$valueString" }.mkString.stripPrefix(Path.SEPARATOR) @@ -433,6 +480,7 @@ private[sql] class DynamicPartitionWriterContainer( override def commitTask(): Unit = { try { outputWriters.values.foreach(_.close()) + outputWriters.clear() super.commitTask() } catch { case cause: Throwable => super.abortTask() @@ -443,55 +491,9 @@ private[sql] class DynamicPartitionWriterContainer( override def abortTask(): Unit = { try { outputWriters.values.foreach(_.close()) + outputWriters.clear() } finally { super.abortTask() } } } - -private[sql] object DynamicPartitionWriterContainer { - ////////////////////////////////////////////////////////////////////////////////////////////////// - // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). - ////////////////////////////////////////////////////////////////////////////////////////////////// - - val charToEscape = { - val bitSet = new java.util.BitSet(128) - - /** - * ASCII 01-1F are HTTP control characters that need to be escaped. 
- * \u000A and \u000D are \n and \r, respectively. - */ - val clist = Array( - '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', - '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', - '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', - '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', - '{', '[', ']', '^') - - clist.foreach(bitSet.set(_)) - - if (Shell.WINDOWS) { - Array(' ', '<', '>', '|').foreach(bitSet.set(_)) - } - - bitSet - } - - def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) - } - - def escapePathName(path: String): String = { - val builder = new StringBuilder() - path.foreach { c => - if (DynamicPartitionWriterContainer.needsEscaping(c)) { - builder.append('%') - builder.append(f"${c.asInstanceOf[Int]}%02x") - } else { - builder.append(c) - } - } - - builder.toString() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index ca30b8e74626f..20afd60cb7767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand @@ -130,7 +130,7 @@ private[sql] class DDLParser( } } - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" /* * describe [extended] table avroTable @@ -138,7 +138,7 @@ private[sql] class DDLParser( */ protected lazy val describeTable: Parser[LogicalPlan] = (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => + case e ~ db ~ tbl => val tblIdentifier = db match { case Some(dbName) => Seq(dbName, tbl) @@ -171,7 +171,7 @@ private[sql] class DDLParser( } protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k,v) } + optionName ~ stringLit ^^ { case k ~ v => (k, v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => @@ -239,7 +239,7 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(schema, partitionColumns)) } - val caseInsensitiveOptions= new CaseInsensitiveMap(options) + val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray @@ -322,19 +322,13 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(data.schema, partitionColumns)), caseInsensitiveOptions) - // For partitioned relation r, r.schema's column ordering is different with the column - // ordering of data.logicalPlan. We need a Project to adjust the ordering. 
- // So, inside InsertIntoHadoopFsRelation, we can safely apply the schema of r.schema to - // the data. - val project = - Project( - r.schema.map(field => new UnresolvedAttribute(Seq(field.name))), - data.logicalPlan) - + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. sqlContext.executePlan( InsertIntoHadoopFsRelation( r, - project, + data.logicalPlan, mode)).toRdd r case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index aaabbadcd651b..d1547fb1e4abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -31,7 +31,7 @@ import org.apache.spark.SerializableWritable import org.apache.spark.sql.{Row, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType /** * ::DeveloperApi:: @@ -93,7 +93,7 @@ trait SchemaRelationProvider { } /** - * ::DeveloperApi:: + * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined @@ -115,6 +115,7 @@ trait SchemaRelationProvider { * * @since 1.4.0 */ +@Experimental trait HadoopFsRelationProvider { /** * Returns a new base relation with the given parameters, a user defined schema, and a list of @@ -378,24 +379,33 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] def refresh(): Unit = { + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] - files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + if (status.getPath.getName.toLowerCase == "_temporary") { + Set.empty + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] + files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + } } leafFiles.clear() - // We don't filter files/directories like _temporary/_SUCCESS here, as specific data sources - // may take advantages over them (e.g. Parquet _metadata and _common_metadata files). 
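A plain java.io.File sketch of the listing policy described in the rewritten refresh() above: recurse into every directory except "_temporary" (which may hold partial output from failed tasks), keep "_"-prefixed files such as _metadata, and afterwards drop dot-prefixed hidden entries like ".DS_Store". Leaf directories and the Hadoop FileSystem API are left out to keep it self-contained:

import java.io.File

object LeafFileListingSketch {
  def listLeafFiles(dir: File): Seq[File] = {
    if (dir.getName.toLowerCase == "_temporary") {
      Seq.empty // partial/corrupted output from failed tasks may live here
    } else {
      val children = Option(dir.listFiles()).map(_.toSeq).getOrElse(Seq.empty)
      val (dirs, files) = children.partition(_.isDirectory)
      files ++ dirs.flatMap(listLeafFiles)
    }
  }

  def main(args: Array[String]): Unit = {
    listLeafFiles(new File("."))
      .filterNot(_.getName.startsWith(".")) // SPARK-8037-style hidden-file filtering
      .take(5)
      .foreach(println)
  }
}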
val statuses = paths.flatMap { path => val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) + }.filterNot { status => + // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories + status.getPath.getName.startsWith(".") } - val (dirs, files) = statuses.partition(_.isDir) + val files = statuses.filterNot(_.isDir) leafFiles ++= files.map(f => f.getPath -> f).toMap leafDirToChildrenFiles ++= files.groupBy(_.getPath.getParent) } @@ -425,8 +435,9 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio // partition values. userDefinedPartitionColumns.map { partitionSchema => val spec = discoverPartitions() + val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(spec.partitionColumns.map(_.dataType)).map { + val literals = values.toSeq.zip(partitionColumnTypes).map { case (value, dataType) => Literal.create(value, dataType) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => @@ -480,9 +491,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } private def discoverPartitions(): PartitionSpec = { + val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() // We use leaf dirs containing data files to discover the schema. val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME) + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference) } /** @@ -493,7 +506,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ override lazy val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column => + StructType(dataSchema ++ partitionColumns.filterNot { column => dataSchemaColumnNames.contains(column.name.toLowerCase) }) } diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 28e90b9520b2c..12fb128149d32 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,11 +36,11 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. 
-log4j.additivity.parquet.hadoop.ParquetRecordReader=false -log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF -log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false -log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF @@ -52,5 +52,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.parquet.hadoop=WARN +log4j.logger.org.apache.parquet.hadoop=WARN log4j.logger.org.apache.spark.sql.parquet=INFO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0772e5e187425..72e60d9aa75cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,8 +25,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.storage.{RDDBlockId, StorageLevel} case class BigData(s: String) @@ -34,8 +32,12 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan + val executedPlan = ctx.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -45,47 +47,47 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - cacheTable("tempTable") + ctx.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != cacheManager.lookupCachedData(testData)) + assert(None != ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == 
cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - cacheTable("tempTable1") + ctx.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - uncacheTable("tempTable2") + ctx.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -93,103 +95,103 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) - table("bigData").unpersist(blocking = true) + ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(ctx.table("bigData").count() === 200000L) + ctx.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - table("testData").cache() - assertCached(table("testData")) - table("testData").unpersist(blocking = true) + ctx.table("testData").cache() + assertCached(ctx.table("testData")) + ctx.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - table("testData").cache() - table("testData").count() - table("testData").unpersist(blocking = true) - assertCached(table("testData"), 0) + ctx.table("testData").cache() + ctx.table("testData").count() + ctx.table("testData").unpersist(blocking = true) + assertCached(ctx.table("testData"), 0) } test("isCached") { - cacheTable("testData") + ctx.cacheTable("testData") - assertCached(table("testData")) - assert(table("testData").queryExecution.withCachedData match { + assertCached(ctx.table("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - uncacheTable("testData") - assert(!isCached("testData")) - assert(table("testData").queryExecution.withCachedData match { + ctx.uncacheTable("testData") + assert(!ctx.isCached("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - cacheTable("testData") - assertCached(table("testData")) + ctx.cacheTable("testData") + assertCached(ctx.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - cacheTable("testData") + ctx.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r @ 
InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - uncacheTable("testData") + ctx.uncacheTable("testData") } test("read from cached table and uncache") { - cacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData")) + ctx.cacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData")) - uncacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData"), 0) + ctx.uncacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - uncacheTable("testData") + ctx.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - cacheTable("selectStar") + ctx.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - uncacheTable("selectStar") + ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - cacheTable("testData") + ctx.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - uncacheTable("testData") + ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -197,7 +199,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!isCached("testData"), "Table 'testData' should not be cached") + assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -206,14 +208,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -221,14 +223,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -236,7 +238,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = 
rddIdOf("testData") assert( @@ -248,7 +250,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - uncacheTable("testData") + ctx.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -256,7 +258,7 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -265,38 +267,38 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - table("t1") - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + ctx.table("t1") + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - cacheTable("t1") + ctx.cacheTable("t1") - assert(isCached("t1")) - assert(isCached("t2")) + assert(ctx.isCached("t1")) + assert(ctx.isCached("t2")) - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) - assert(!isCached("t2")) + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!ctx.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") - clearCache() - assert(cacheManager.isEmpty) + ctx.cacheTable("t1") + ctx.cacheTable("t2") + ctx.clearCache() + assert(ctx.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") sql("Clear CACHE") - assert(cacheManager.isEmpty) + assert(ctx.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { @@ -305,8 +307,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") assert((accsSize + 2) == Accumulators.originals.size) } @@ -317,8 +319,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - uncacheTable("t1") - uncacheTable("t2") + ctx.uncacheTable("t1") + ctx.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 9bdf201b3be7c..4f5484f1368d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,23 +19,31 @@ package org.apache.spark.sql import org.scalatest.Matchers._ +import org.apache.spark.sql.execution.Project import 
org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("alias") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + assert(df.select(df("a").as("b")).columns.head === "b") + assert(df.select(df("a").alias("b")).columns.head === "b") + } + test("single explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select(explode('intList)), Row(1) :: Row(2) :: Row(3) :: Nil) } test("explode and other columns") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select($"a", explode('intList)), @@ -45,13 +53,13 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer( df.select($"*", explode('intList)), - Row(1, Seq(1,2,3), 1) :: - Row(1, Seq(1,2,3), 2) :: - Row(1, Seq(1,2,3), 3) :: Nil) + Row(1, Seq(1, 2, 3), 1) :: + Row(1, Seq(1, 2, 3), 2) :: + Row(1, Seq(1, 2, 3), 3) :: Nil) } test("aliased explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select(explode('intList).as('int)).select('int), @@ -79,7 +87,7 @@ class ColumnExpressionSuite extends QueryTest { } test("self join explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") val exploded = df.select(explode('intList).as('i)) checkAnswer( @@ -206,7 +214,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -267,7 +275,7 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -280,7 +288,7 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } - val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -406,7 +414,7 @@ class ColumnExpressionSuite extends QueryTest { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. 
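The expanded "rand" test further below spells out why the optimizer keeps the Project that computes rand separate from the Projects built on top of it. A toy illustration in plain Scala (scala.util.Random standing in for Catalyst's Rand expression; this is not the optimizer itself):

import scala.util.Random

object NonDeterministicCollapseSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(5L)

    // Keep the non-deterministic step separate: draw the value once...
    val rand = rng.nextDouble()
    // ...then deterministic projections on top of it can be combined freely.
    val rand1 = rand + 1
    val rand2 = rand - 1
    println(rand1 - rand2) // 2.0 up to rounding (the test below allows a +- 0.0001 tolerance)

    // Collapsing through the non-deterministic step would duplicate the draw,
    // and the two draws differ:
    println((rng.nextDouble() + 1) - (rng.nextDouble() - 1)) // generally far from 2.0
  }
}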
- val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -416,7 +424,7 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( df.select(sparkPartitionId()), Row(0) @@ -446,13 +454,51 @@ class ColumnExpressionSuite extends QueryTest { } test("rand") { - val randCol = testData.select('key, rand(5L).as("rand")) + val randCol = testData.select($"key", rand(5L).as("rand")) randCol.columns.length should be (2) val rows = randCol.collect() rows.foreach { row => assert(row.getDouble(1) <= 1.0) assert(row.getDouble(1) >= 0.0) } + + def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { + val projects = df.queryExecution.executedPlan.collect { + case project: Project => project + } + assert(projects.size === expectedNumProjects) + } + + // We first create a plan with two Projects. + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // Because Rand function is not deterministic, the column rand is not deterministic. + // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2] + // and Project [key, Rand 5 AS rand]. The final plan still has two Projects. + val dfWithTwoProjects = + testData + .select($"key", (rand(5L) + 1).as("rand")) + .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2")) + checkNumProjects(dfWithTwoProjects, 2) + + // Now, we add one more project rand1 - rand2 on top of the query plan. + // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated + // rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2. + // So, the plan will be optimized from ... + // Project [(rand1 - rand2) AS (rand1 - rand2)] + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // to ... 
+ // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] + // Project [key, Rand 5 AS rand] + // LogicalRDD [key, value] + val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") + checkNumProjects(dfWithThreeProjects, 2) + dfWithThreeProjects.collect().foreach { row => + assert(row.getDouble(0) === 2.0 +- 0.0001) + } } test("randn") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 35a574f354741..790b405c72697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType class DataFrameAggregateSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -67,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false") + ctx.conf.setConf("spark.sql.retainGroupColumns", "false") checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true") + ctx.conf.setConf("spark.sql.retainGroupColumns", "true") } test("agg without groups") { @@ -148,12 +149,12 @@ class DataFrameAggregateSuite extends QueryTest { test("null count") { checkAnswer( testData3.groupBy('a).agg(count('b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( testData3.groupBy('a).agg(count('a + 'b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b1e0faa310b68..53c2befb73702 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ /** @@ -27,6 +26,9 @@ import org.apache.spark.sql.types._ */ class DataFrameFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") val row = df.select(array("a", "b")).first() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 2d2367d6e7292..fbb30706a4943 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} -import org.apache.spark.sql.test.TestSQLContext.implicits._ - - class DataFrameImplicitsSuite extends 
QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("RDD of tuples") { checkAnswer( - sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -37,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - sc.parallelize(1 to 10).toDF("intCol"), + ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - sc.parallelize(1L to 10L).toDF("longCol"), + ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 787f3f175fea2..6165764632c29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameJoinSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") @@ -34,6 +34,15 @@ class DataFrameJoinSuite extends QueryTest { Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) } + test("join - join using multiple columns") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "int2")), + Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -49,7 +58,8 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) } test("join - using aliases after self join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 41b4f02e6a294..495701d4f616c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameNaFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + def createDF(): DataFrame = { Seq[(String, 
java.lang.Integer, java.lang.Double)]( ("Bob", 16, 176.5), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b1845a9180c..0d3ff899dad72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql -import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.SparkFunSuite -class DataFrameStatSuite extends FunSuite { - - val sqlCtx = TestSQLContext - def toLetter(i: Int): String = (i + 97).toChar.toString +class DataFrameStatSuite extends SparkFunSuite { + + private val sqlCtx = org.apache.spark.sql.test.TestSQLContext + import sqlCtx.implicits._ + + private def toLetter(i: Int): String = (i + 97).toChar.toString test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") @@ -74,10 +74,10 @@ class DataFrameStatSuite extends FunSuite { val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) assert(rows(0).get(0).toString === "0") assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === null) + assert(rows(0).get(2) === 0L) assert(rows(1).get(0).toString === "1") assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === null) + assert(rows(1).get(2) === 0L) assert(rows(2).get(0).toString === "2") assert(rows(2).getLong(1) === 2L) assert(rows(2).getLong(2) === 1L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0dcba80ef2a20..bb8621abe64ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,17 +21,19 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("analysis error should be eagerly reported") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -45,11 +47,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") testData.select('nonExistentName) // Set the flag back to original value before this test. 
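// Editorial aside (illustrative sketch, not part of the patch): the save/set/restore dance
// around SQLConf flags, as in the eager-analysis test just above, can be made failure-safe
// with try/finally, which is the shape later hunks in this patch adopt (SPARK-6899,
// SPARK-6927). `withConf` below is a hypothetical helper, not an existing Spark API.
import org.apache.spark.sql.SQLContext

def withConf[A](ctx: SQLContext, key: String, value: String)(body: => A): A = {
  val original = ctx.getConf(key, "")  // remember the current value ("" as a stand-in default)
  ctx.setConf(key, value)              // apply the test-specific setting
  try body
  finally ctx.setConf(key, original)   // restored even when an assertion inside body fails
}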
- TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("dataframe toString") { @@ -59,7 +61,7 @@ class DataFrameSuite extends QueryTest { } test("rename nested groupby") { - val df = Seq((1,(1,1))).toDF() + val df = Seq((1, (1, 1))).toDF() checkAnswer( df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"), @@ -67,12 +69,12 @@ class DataFrameSuite extends QueryTest { } test("invalid plan toString, debug mode") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + val oldSetting = ctx.conf.dataFrameEagerAnalysis + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - TestSQLContext.debug() + ctx.debug() val badPlan = testData.select('badColumn) @@ -81,7 +83,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("access complex data") { @@ -97,8 +99,8 @@ class DataFrameSuite extends QueryTest { } test("empty data frame") { - assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(TestSQLContext.emptyDataFrame.count() === 0) + assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(ctx.emptyDataFrame.count() === 0) } test("head and take") { @@ -211,23 +213,23 @@ class DataFrameSuite extends QueryTest { test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( testData2.orderBy(asc("a"), desc("b")), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( arrayData.toDF().orderBy('data.getItem(0).asc), @@ -311,7 +313,7 @@ class DataFrameSuite extends QueryTest { } test("replace column using withColumn") { - val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -331,7 +333,52 @@ class DataFrameSuite extends QueryTest { checkAnswer( df, testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key","value")) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop column using drop with column reference") { + val col = testData("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) + } + + test("drop unknown column (no-op) with column 
reference") { + val col = Column("random") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop unknown column with same name (no-op) with column reference") { + val col = Column("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop column after join with duplicate columns using column reference") { + val newSalary = salary.withColumnRenamed("personId", "id") + val col = newSalary("id") + // this join will result in duplicate "id" columns + val joinedDf = person.join(newSalary, + person("id") === newSalary("id"), "inner") + // remove only the "id" column that was associated with newSalary + val df = joinedDf.drop(col) + checkAnswer( + df, + joinedDf.collect().map { + case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) => + Row(id, name, age, salary) + }.toSeq) + assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary")) + assert(df("id") == person("id")) } test("withColumnRenamed") { @@ -347,7 +394,7 @@ class DataFrameSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -364,24 +411,24 @@ class DataFrameSuite extends QueryTest { test("describe") { val describeTestData = Seq( - ("Bob", 16, 176), + ("Bob", 16, 176), ("Alice", 32, 164), ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + ("Amy", 24, 180)).toDF("name", "age", "height") val describeResult = Seq( - Row("count", "4", "4"), - Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), - Row("min", "16", "164"), - Row("max", "60", "192")) + Row("count", "4", "4"), + Row("mean", "33.0", "178.0"), + Row("stddev", "16.583123951777", "10.0"), + Row("min", "16", "164"), + Row("max", "60", "192")) val emptyDescribeResult = Seq( - Row("count", "0", "0"), - Row("mean", null, null), - Row("stddev", null, null), - Row("min", null, null), - Row("max", null, null)) + Row("count", "0", "0"), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) @@ -442,19 +489,22 @@ class DataFrameSuite extends QueryTest { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = TestSQLContext.createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.rdd.collect() } test("SPARK-6899") { - val originalValue = TestSQLContext.conf.codegenEnabled - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + val originalValue = ctx.conf.codegenEnabled + ctx.setConf(SQLConf.CODEGEN_ENABLED, "true") + try{ + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + } finally { + 
ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -465,14 +515,14 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df = ctx.read.json(ctx.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df2 = ctx.read.json(ctx.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -492,7 +542,7 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7324 dropDuplicates") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -540,41 +590,49 @@ class DataFrameSuite extends QueryTest { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") + val res1 = ctx.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") + val res2 = ctx.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = TestSQLContext.range(1, -2).select("id") + val res3 = ctx.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") + val res4 = ctx.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") + val res5 = ctx.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") + val res6 = ctx.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") + val res7 = ctx.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + + // only end provided as argument + val res10 = ctx.range(10).select("id") + assert(res10.count == 10) + assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res11 = ctx.range(-1).select("id") + assert(res11.count == 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 037d392c1f929..ffd26c4f5a7c2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.logicalPlanToSparkQuery + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = sql(sqlString) + val df = ctx.sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -61,9 +62,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -94,22 +95,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } } test("broadcasted hash join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), @@ -117,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -126,17 +127,17 @@ class JoinSuite extends QueryTest with 
BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } - sql("UNCACHE TABLE testData") + ctx.sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } @@ -167,10 +168,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where($"a" === 1).as("y") checkAnswer( x.join(y).where($"x.a" === $"y.a"), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil ) } @@ -241,7 +242,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -255,7 +256,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -301,7 +302,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -310,7 +311,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -362,7 +363,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
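// Editorial aside (illustrative): assertJoin, defined earlier in this suite, verifies operator
// selection by running the query and pattern matching the physical plan for the expected join
// node. The same check written stand-alone against the suite-local ctx; the query string is
// one already used by this suite, and which operator appears depends on the current conf.
val planned = ctx.sql("SELECT * FROM testData JOIN testData2 ON key = a")
val joinOperators = planned.queryExecution.sparkPlan.collect {
  case j: org.apache.spark.sql.execution.joins.BroadcastHashJoin => j
  case j: org.apache.spark.sql.execution.joins.ShuffledHashJoin => j
  case j: org.apache.spark.sql.execution.joins.SortMergeJoin => j
}
assert(joinOperators.size === 1, "expected exactly one physical join operator in the plan")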
checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -371,7 +372,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -386,7 +387,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -401,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -411,11 +412,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") - val tmp = conf.autoBroadcastJoinThreshold + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -423,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -431,12 +432,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) - sql("UNCACHE TABLE testData") + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + ctx.sql("UNCACHE TABLE testData") } test("left semi join") { - val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index f9f41eb358bd5..2089660c52bf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,49 +19,47 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} class ListTablesSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ - val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") } after { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - tables().filter("tableName = 
'ListTablesSuiteTable'"), + ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -69,19 +67,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(tables(), sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + ctx.sql( + "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - dropTempTable("tables") + ctx.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index c4281c4b55c02..0a38af2b4c889 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,36 +17,29 @@ package org.apache.spark.sql -import java.lang.{Double => JavaDouble} - import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ - -private[this] object MathExpressionsTestData { - - case class DoubleData(a: JavaDouble, b: JavaDouble) - val doubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() - - val nnDoubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() - - case class NullDoubles(a: JavaDouble) - val nullDoubles = - TestSQLContext.sparkContext.parallelize( - NullDoubles(1.0) :: - NullDoubles(2.0) :: - NullDoubles(3.0) :: - NullDoubles(null) :: Nil - ).toDF() + + +private object MathExpressionsTestData { + case class DoubleData(a: java.lang.Double, b: java.lang.Double) + case class NullDoubles(a: 
java.lang.Double) } class MathExpressionsSuite extends QueryTest { import MathExpressionsTestData._ - def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() + + private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF() + + private lazy val nullDoubles = + Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() + + private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( c: Column => Column, f: T => T): Unit = { checkAnswer( @@ -65,7 +58,8 @@ class MathExpressionsSuite extends QueryTest { ) } - def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = + { checkAnswer( nnDoubleData.select(c('a)), (1 to 10).map(n => Row(f(n * 0.1))) @@ -89,7 +83,7 @@ class MathExpressionsSuite extends QueryTest { ) } - def testTwoToOneMathFunction( + private def testTwoToOneMathFunction( c: (Column, Column) => Column, d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { @@ -206,7 +200,7 @@ class MathExpressionsSuite extends QueryTest { } test("log") { - testOneToOneNonNegativeMathFunction(log, math.log) + testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) } test("log10") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index bbf9ab113ca43..98ba3c99283a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -67,6 +67,10 @@ class QueryTest extends PlanTest { checkAnswer(df, Seq(expectedAnswer)) } + protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index fb3ba4bc1b908..d84b57af9c882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer -import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -class RowSuite extends FunSuite { +class RowSuite extends SparkFunSuite { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ test("create row") { val expected = new GenericMutableRow(4) @@ -56,7 +57,7 @@ class RowSuite extends FunSuite { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) val instance = serializer.newInstance() val ser = 
instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index bf73d0c7074a5..76d0dd1744a41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,68 +17,64 @@ package org.apache.spark.sql -import org.scalatest.FunSuiteLike -import org.apache.spark.sql.test._ +class SQLConfSuite extends QueryTest { -/* Implicits */ -import TestSQLContext._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext -class SQLConfSuite extends QueryTest with FunSuiteLike { - - val testKey = "test.key.0" - val testVal = "test.val.0" + private val testKey = "test.key.0" + private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(TestSQLContext.sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") == "true") + val newContext = new SQLContext(ctx.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - conf.clear() - assert(getAllConfs.size === 0) + ctx.conf.clear() + assert(ctx.getAllConfs.size === 0) - setConf(testKey, testVal) - assert(getConf(testKey) == testVal) - assert(getConf(testKey, testVal + "_") == testVal) - assert(getAllConfs.contains(testKey)) + ctx.setConf(testKey, testVal) + assert(ctx.getConf(testKey) === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
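// Editorial aside (illustrative): the point of the comment above, in code. SQL conf entries
// stay settable on a live context, while the SparkConf captured by a SparkContext is fixed at
// construction time (sparkContext.getConf hands back a copy). The key name is arbitrary.
val exampleCtx = org.apache.spark.sql.test.TestSQLContext
exampleCtx.setConf("spark.sql.example.key", "42")
assert(exampleCtx.getConf("spark.sql.example.key") === "42")
val sparkConfCopy = exampleCtx.sparkContext.getConf
sparkConfCopy.set("spark.example.key", "ignored")  // mutates only the copy, not the running context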
- assert(TestSQLContext.getConf(testKey) == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getAllConfs.contains(testKey)) + assert(ctx.getConf(testKey) == testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) - conf.clear() + ctx.conf.clear() } test("parse SQL set commands") { - conf.clear() - sql(s"set $testKey=$testVal") - assert(getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + ctx.conf.clear() + ctx.sql(s"set $testKey=$testVal") + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) - sql("set some.property=20") - assert(getConf("some.property", "0") == "20") - sql("set some.property = 40") - assert(getConf("some.property", "0") == "40") + ctx.sql("set some.property=20") + assert(ctx.getConf("some.property", "0") === "20") + ctx.sql("set some.property = 40") + assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - sql(s"set $key=$vs") - assert(getConf(key, "0") == vs) + ctx.sql(s"set $key=$vs") + assert(ctx.getConf(key, "0") === vs) - sql(s"set $key=") - assert(getConf(key, "0") == "") + ctx.sql(s"set $key=") + assert(ctx.getConf(key, "0") === "") - conf.clear() + ctx.conf.clear() } test("deprecated property") { - conf.clear() - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") + ctx.conf.clear() + ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index f186bc1c18123..c8d8796568a41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,33 +17,32 @@ package org.apache.spark.sql -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.SparkFunSuite -class SQLContextSuite extends FunSuite with BeforeAndAfterAll { +class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - private val testSqlContext = TestSQLContext - private val testSparkContext = TestSQLContext.sparkContext + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(testSqlContext) + SQLContext.setLastInstantiatedContext(ctx) } test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(testSparkContext) + val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(testSparkContext) - assert(SQLContext.getOrCreate(testSparkContext) != null, + val sqlContext = new SQLContext(ctx.sparkContext) + assert(SQLContext.getOrCreate(ctx.sparkContext) != null, 
"SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7c47fe454b6dc..5babc4332cc77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,20 +24,19 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { +class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - import org.apache.spark.sql.test.TestSQLContext.implicits._ - val sqlCtx = TestSQLContext + val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + import sqlContext.sql test("SPARK-6743: no columns from cache") { Seq( @@ -46,14 +45,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { (43, 81, 24) ).toDF("a", "b", "c").registerTempTable("cachedData") - cacheTable("cachedData") + sqlContext.cacheTable("cachedData") checkAnswer( sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) } test("self join with aliases") { - Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") checkAnswer( sql( @@ -76,7 +75,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("self join with alias in agg") { - Seq(1,2,3) + Seq(1, 2, 3) .map(i => (i, i.toString)) .toDF("int", "str") .groupBy("str") @@ -94,14 +93,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -113,12 +112,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) + Seq(1, 1, 2, 2, 3, 3).map(Row(_)) ) } test("grouping on nested fields") { - read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + 
sqlContext.read.json(sqlContext.sparkContext.parallelize( + """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -135,8 +135,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6201 IN type conversion") { - read.json( - sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + sqlContext.read.json( + sqlContext.sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -157,12 +158,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("aggregation with codegen") { - val originalValue = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") // Prepare a table that we can group some rows. - table("testData") - .unionAll(table("testData")) - .unionAll(table("testData")) + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -184,77 +185,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedResults) } - // Just to group rows. - testCodeGen( - "SELECT key FROM testData3x GROUP BY key", - (1 to 100).map(Row(_))) - // COUNT - testCodeGen( - "SELECT key, count(value) FROM testData3x GROUP BY key", - (1 to 100).map(i => Row(i, 3))) - testCodeGen( - "SELECT count(key) FROM testData3x", - Row(300) :: Nil) - // COUNT DISTINCT ON int - testCodeGen( - "SELECT value, count(distinct key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 1))) - testCodeGen( - "SELECT count(distinct key) FROM testData3x", - Row(100) :: Nil) - // SUM - testCodeGen( - "SELECT value, sum(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 3 * i))) - testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", - Row(5050 * 3, 5050 * 3.0) :: Nil) - // AVERAGE - testCodeGen( - "SELECT value, avg(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT avg(key) FROM testData3x", - Row(50.5) :: Nil) - // MAX - testCodeGen( - "SELECT value, max(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT max(key) FROM testData3x", - Row(100) :: Nil) - // MIN - testCodeGen( - "SELECT value, min(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT min(key) FROM testData3x", - Row(1) :: Nil) - // Some combinations. - testCodeGen( - """ - |SELECT - | value, - | sum(key), - | max(key), - | min(key), - | avg(key), - | count(key), - | count(distinct key) - |FROM testData3x - |GROUP BY value - """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) - testCodeGen( - "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 1, 50.5, 300, 100) :: Nil) - // Aggregate with Code generation handling all null values - testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) - - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try { + // Just to group rows. 
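// Editorial aside (illustrative; the real body of testCodeGen is hidden by the hunk header
// above): one way such a helper can confirm that code generation actually kicked in is to
// look for a GeneratedAggregate node in the executed plan before comparing results. sqlText
// and expectedResults are the helper's parameters shown in the signature above.
val generated = sqlContext.sql(sqlText)
val usedCodegen = generated.queryExecution.executedPlan.collect {
  case agg: org.apache.spark.sql.execution.GeneratedAggregate => agg
}.nonEmpty
assert(usedCodegen, s"expected generated aggregation for: $sqlText")
checkAnswer(generated, expectedResults)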
+ testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + } finally { + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } test("Add Parser of SQL COALESCE()") { @@ -354,7 +357,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq(Row(3,1), Row(3,2)) + Seq(Row(3, 1), Row(3, 2)) ) } @@ -371,16 +374,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1,3), Row(2,3), Row(3,3))) + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } test("literal in agg grouping expressions") { checkAnswer( sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) } test("aggregates with nulls") { @@ -405,19 +408,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( sql("SELECT * FROM 
testData2 ORDER BY a DESC, b DESC"), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), @@ -445,37 +448,43 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("external sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "true") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("SPARK-6927 sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") - setConf(SQLConf.CODEGEN_ENABLED, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + try{ + sortTest() + } finally { + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("SPARK-6927 external sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") - setConf(SQLConf.EXTERNAL_SORT, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") + try { + sortTest() + } finally { + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } } test("limit") { @@ -508,7 +517,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + sql( + "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } @@ -552,7 +562,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq(Row(2147483645.0,1), Row(2.0,2))) + Seq(Row(2147483645.0, 1), Row(2.0, 2))) } test("count") { @@ -619,10 +629,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = 
y.a""".stripMargin), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil) } test("inner join, no matches") { @@ -855,7 +865,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SET commands semantics using sql()") { - conf.clear() + sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -887,17 +897,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql(s"SET $nonexistentKey"), Row(s"$nonexistentKey=") ) - conf.clear() + sqlContext.conf.clear() } test("SET commands with illegal or inappropriate argument") { - conf.clear() + sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - conf.clear() + sqlContext.conf.clear() } test("apply schema") { @@ -915,7 +925,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -945,7 +955,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -970,7 +980,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) + val df3 = sqlContext.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1015,7 +1025,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1030,7 +1040,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3371 Renaming a function expression with group by gives error") { - TestSQLContext.udf.register("len", (s: String) => s.length) + sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1211,9 +1221,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize( + val data = sqlContext.sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - read.json(data).registerTempTable("records") + sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1254,35 +1264,37 @@ class SQLQuerySuite 
extends QueryTest with BeforeAndAfterAll { } test("SPARK-4322 Grouping field with struct field as sub expression") { - read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - dropTempTable("data") + sqlContext.dropTempTable("data") - read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json( + sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - dropTempTable("data") + sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) + Seq(2, 3, 3, 4, 4, 5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)) ) } test("Supporting relational operator '<=>' in Spark SQL") { - val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil - val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil + val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") - val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil - val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil + val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1290,23 +1302,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("Multi-column COUNT(DISTINCT ...)") { - val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { - setConf(SQLConf.CASE_SENSITIVE, "false") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false") val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - setConf(SQLConf.CASE_SENSITIVE, "true") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true") } test("SPARK-6145: ORDER BY test for nested fields") { - read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) 
.registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1318,17 +1331,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: special cases") { - read.json(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - read.json(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("SPARK-7952: fix the equality check between boolean and numeric types") { + withTempTable("t") { + // numeric field i, boolean field j, result of i = j, result of i <=> j + Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( + (1, true, true, true), + (0, false, true, true), + (2, true, false, false), + (2, false, false, false), + (null, true, null, false), + (null, false, null, false), + (0, null, null, false), + (1, null, null, false), + (null, null, null, true) + ).toDF("i", "b", "r1", "r2").registerTempTable("t") + + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 3fa00fd9d0ccb..ece3d6fdf2af5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,10 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( stringField: String, @@ -74,45 +72,44 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends FunSuite { +class ScalaReflectionRelationSuite extends SparkFunSuite { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectData") + new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) + Seq(data).toDF().registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === + assert(ctx.sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), - new Timestamp(12345), Seq(1,2,3))) + new Timestamp(12345), Seq(1, 2, 3))) } test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - val rdd = 
sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectNullData") + Seq(data).toDF().registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectOptionalData") + Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === + assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. test("query binary data") { - val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.toDF().registerTempTable("reflectBinary") + Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + val result = ctx.sql("SELECT data FROM reflectBinary") + .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -128,10 +125,9 @@ class ScalaReflectionRelationSuite extends FunSuite { Map(10 -> 100L, 20 -> 200L), Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectComplexData") - assert(sql("SELECT * FROM reflectComplexData").collect().head === + Seq(data).toDF().registerTempTable("reflectComplexData") + assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 6f6d3c9c243d4..e55c9e460b791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.test.TestSQLContext -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(TestSQLContext.sparkContext) + val sqlContext = new SQLContext(ctx.sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 8fbc2d23d47e6..725a18bfae3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -109,8 +109,8 @@ object TestData { case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: - ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) 
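The ScalaReflectionRelationSuite hunks above replace `sparkContext.parallelize(data :: Nil).toDF()` with the shorter `Seq(data).toDF()`; both run through the same implicit conversions. A sketch of the pattern, with `ctx` standing in for whatever SQLContext is in scope:

import ctx.implicits._

case class Person(name: String, age: Int)

// Either form yields the same DataFrame; the local-Seq form simply skips the explicit RDD.
val viaRdd = ctx.sparkContext.parallelize(Person("a", 1) :: Nil).toDF()
val viaSeq = Seq(Person("a", 1)).toDF()

viaSeq.registerTempTable("people")
ctx.sql("SELECT name, age FROM people").collect()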
:: Nil) arrayData.toDF().registerTempTable("arrayData") case class MapData(data: scala.collection.Map[Int, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d615542ab50a7..703a34c47ec20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,43 +17,83 @@ package org.apache.spark.sql -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ -import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("built-in fixed arity expressions") { + val df = ctx.emptyDataFrame + df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") + } + + test("built-in vararg expressions") { + val df = Seq((1, 2)).toDF("a", "b") + df.selectExpr("array(a, b)") + df.selectExpr("struct(a, b)") + } + + test("built-in expressions with multiple constructors") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect() + } + + test("count") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(a)") + } + + test("count distinct") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(distinct a)") + } + + test("error reporting for incorrect number of arguments") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("substr('abcd', 2, 3, 4)") + } + assert(e.getMessage.contains("arguments")) + } + + test("error reporting for undefined functions") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("a_function_that_does_not_exist()") + } + assert(e.getMessage.contains("undefined function")) + } + test("Simple UDF") { - udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) + ctx.udf.register("strLenScala", (_: String).length) + assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) + ctx.udf.register("random0", () => { Math.random()}) + assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { - udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - sql("SELECT returnStruct('test', 'test2') as ret") + ctx.sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. 
- assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index dc2d43a197f40..45c9f06941c10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import java.io.File - -import org.apache.spark.util.Utils - import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -28,12 +24,11 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -72,11 +67,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD = sparkContext.parallelize(points).toDF() + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val pointsRDD = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -94,10 +91,10 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - sql("SELECT testType(features) from points"), + ctx.sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 7cefcf44061ce..339e719f39f16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ -class ColumnStatsSuite extends FunSuite { +class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) 
testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 1e105e259dce7..a1e76eaa982cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,15 +23,14 @@ import java.sql.Timestamp import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.serializer.KryoRegistrator -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class ColumnTypeSuite extends FunSuite with Logging { +class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { @@ -73,7 +72,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) @@ -167,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val serializer = new SparkSqlSerializer(conf).newInstance() val buffer = ByteBuffer.allocate(512) - val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val obj = CustomClass(Int.MaxValue, Long.MaxValue) val serializedObj = serializer.serialize(obj).array() GENERIC.append(serializer.serialize(obj).array(), buffer) @@ -278,7 +277,7 @@ private[columnar] object CustomerSerializer extends Serializer[CustomClass] { override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { val a = input.readInt() val b = input.readLong() - CustomClass(a,b) + CustomClass(a, b) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 56591d9dba29e..fa3b8144c086e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -21,8 +21,6 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -31,8 +29,12 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Make sure the tables are loaded. 
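The ColumnTypeSuite hunk above pushes a custom type through the GENERIC column type via SparkSqlSerializer, which picks up Kryo serializers from the registrator named in the SparkConf. A sketch of that wiring; the class names here are illustrative, not part of the patch:

import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

case class CustomClass(a: Int, b: Long)

object CustomClassSerializer extends Serializer[CustomClass] {
  override def write(kryo: Kryo, output: Output, obj: CustomClass): Unit = {
    output.writeInt(obj.a)
    output.writeLong(obj.b)
  }
  override def read(kryo: Kryo, input: Input, clazz: Class[CustomClass]): CustomClass =
    CustomClass(input.readInt(), input.readLong())
}

class CustomRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit =
    kryo.register(classOf[CustomClass], CustomClassSerializer)
}

// The registrator is picked up through the standard Kryo configuration key.
val conf = new SparkConf().set("spark.kryo.registrator", classOf[CustomRegistrator].getName)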
TestData + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.{logicalPlanToSparkQuery, sql} + test("simple columnar query") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -40,16 +42,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - cacheTable("sizeTst") + ctx.cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - conf.autoBroadcastJoinThreshold) + ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + ctx.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -58,7 +60,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -70,7 +72,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("repeatedData") + ctx.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -82,7 +84,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("nullableRepeatedData") + ctx.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -94,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq.map(Row.fromTuple)) - cacheTable("timestamps") + ctx.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -106,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - cacheTable("withEmptyParts") + ctx.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -155,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Create a RDD for the schema val rdd = - sparkContext.parallelize((1 to 100), 10).map { i => + ctx.sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -173,20 +175,20 @@ class InMemoryColumnarQuerySuite extends QueryTest { new Timestamp(i), (1 to i).toSeq, (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, - Row((i - 0.25).toFloat, (1 to i).toSeq)) + Row((i - 0.25).toFloat, Seq(true, false, null))) } - createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + 
ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - isCached("InMemoryCache_different_data_types"), + ctx.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - table("InMemoryCache_different_data_types").collect()) - dropTempTable("InMemoryCache_different_data_types") + ctx.table("InMemoryCache_different_data_types").collect()) + ctx.dropTempTable("InMemoryCache_different_data_types") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index a0702144f942c..2a6e0c376551a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.DataType @@ -39,7 +38,7 @@ object TestNullableColumnAccessor { } } -class NullableColumnAccessorSuite extends FunSuite { +class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 3a5605d2335d7..cb4e9f1eb7f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ @@ -35,7 +34,7 @@ object TestNullableColumnBuilder { } } -class NullableColumnBuilderSuite extends FunSuite { +class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2a0b701cad7fa..6545c6b314a4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,43 +17,46 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { - val 
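The InMemoryColumnarQuerySuite changes above drive caching through the explicit context. The equivalent user-facing flow, sketched against a hypothetical SQLContext named `ctx`:

import ctx.implicits._

ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("key", "value")
  .registerTempTable("cached_tbl")

ctx.cacheTable("cached_tbl")                  // or: ctx.sql("CACHE TABLE cached_tbl")
assert(ctx.isCached("cached_tbl"))

ctx.sql("SELECT COUNT(*) FROM cached_tbl").collect()  // now served from the columnar cache

ctx.uncacheTable("cached_tbl")
ctx.dropTempTable("cached_tbl")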
originalColumnBatchSize = conf.columnBatchSize - val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10") - val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") // Enable in-memory table scan accumulators - setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { - setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } before { - cacheTable("pruningData") + ctx.cacheTable("pruningData") } after { - uncacheTable("pruningData") + ctx.uncacheTable("pruningData") } // Comparisons @@ -107,7 +110,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = sql(query) + val df = ctx.sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 8b518f094174c..20d65a74e3b7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ -class BooleanBitSetSuite extends FunSuite { +class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ def skeleton(count: Int) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index 64b70552eb047..acfab6586c0d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ 
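The PartitionBatchPruningSuite above flips two SQL options around each run and restores them afterwards. The same save/set/restore shape, mirroring the suite's own calls (`ctx` is hypothetical):

import org.apache.spark.sql.SQLConf

val oldBatchSize = ctx.conf.columnBatchSize
val oldPruning = ctx.conf.inMemoryPartitionPruning

// Ten rows per batch makes batch-level pruning observable on tiny test data.
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
try {
  // cache a table and run filtered queries here
} finally {
  ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, oldBatchSize.toString)
  ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, oldPruning.toString)
}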
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.columnar.compression import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class DictionaryEncodingSuite extends FunSuite { - testDictionaryEncoding(new IntColumnStats, INT) - testDictionaryEncoding(new LongColumnStats, LONG) +class DictionaryEncodingSuite extends SparkFunSuite { + testDictionaryEncoding(new IntColumnStats, INT) + testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) def testDictionaryEncoding[T <: AtomicType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index bfd99f143bedc..2111e9fbe62cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType -class IntegralDeltaSuite extends FunSuite { - testIntegralDelta(new IntColumnStats, INT, IntDelta) +class IntegralDeltaSuite extends SparkFunSuite { + testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( @@ -116,7 +115,7 @@ class IntegralDeltaSuite extends FunSuite { test(s"$scheme: simple case") { val input = columnType match { - case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index fde7a4595be0e..67ec08f594a43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class RunLengthEncodingSuite extends FunSuite { +class RunLengthEncodingSuite extends SparkFunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) - testRunLengthEncoding(new ByteColumnStats, BYTE) - testRunLengthEncoding(new ShortColumnStats, SHORT) - testRunLengthEncoding(new IntColumnStats, INT) - testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new ByteColumnStats, BYTE) + testRunLengthEncoding(new 
ShortColumnStats, SHORT) + testRunLengthEncoding(new IntColumnStats, INT) + testRunLengthEncoding(new LongColumnStats, LONG) + testRunLengthEncoding(new StringColumnStats, STRING) def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 523be56df65ba..45a7e8fe68f72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -31,7 +30,7 @@ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ -class PlannerSuite extends FunSuite { +class PlannerSuite extends SparkFunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 15337c4045436..8631e247c6c05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} -import org.scalatest.{FunSuite, BeforeAndAfterAll} +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.serializer.Serializer -import org.apache.spark.ShuffleDependency +import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -class SparkSqlSerializer2DataTypeSuite extends FunSuite { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { // Make sure that we will not use serializer2 for unsupported data types. def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = @@ -74,11 +74,13 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll var numShufflePartitions: Int = _ var useSerializer2: Boolean = _ + protected lazy val ctx = TestSQLContext + override def beforeAll(): Unit = { - numShufflePartitions = conf.numShufflePartitions - useSerializer2 = conf.useSqlSerializer2 + numShufflePartitions = ctx.conf.numShufflePartitions + useSerializer2 = ctx.conf.useSqlSerializer2 - sql("set spark.sql.useSerializer2=true") + ctx.sql("set spark.sql.useSerializer2=true") val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, @@ -94,7 +96,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll // Create a RDD with all data types supported by SparkSqlSerializer2. 
val rdd = - sparkContext.parallelize((1 to 1000), 10).map { i => + ctx.sparkContext.parallelize((1 to 1000), 10).map { i => Row( s"str${i}: test serializer2.", s"binary${i}: test serializer2.".getBytes("UTF-8"), @@ -112,15 +114,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll new Timestamp(i)) } - createDataFrame(rdd, schema).registerTempTable("shuffle") + ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") super.beforeAll() } override def afterAll(): Unit = { - dropTempTable("shuffle") - sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - sql(s"set spark.sql.useSerializer2=$useSerializer2") + ctx.dropTempTable("shuffle") + ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") super.afterAll() } @@ -141,16 +143,16 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } test("key schema and value schema are not nulls") { - val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, - table("shuffle").collect()) + ctx.table("shuffle").collect()) } test("key schema is null") { val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = sql(s"SELECT $aggregations FROM shuffle") + val df = ctx.sql(s"SELECT $aggregations FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, @@ -158,15 +160,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } test("value schema is null") { - val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") + val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert( - df.map(r => r.getString(0)).collect().toSeq === - table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + assert(df.map(r => r.getString(0)).collect().toSeq === + ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) } test("no map output field") { - val df = sql(s"SELECT 1 + 1 FROM shuffle") + val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) } } @@ -177,8 +178,8 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { super.beforeAll() // Sort merge will not be triggered. val bypassMergeThreshold = - sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") + ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") } } @@ -189,7 +190,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite super.beforeAll() // To trigger the sort merge. 
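The serializer2 suites above toggle options through SQL SET statements instead of setConf; both routes write the same session-level SQLConf. A small sketch of that equivalence (`ctx` is hypothetical):

// A SET statement and setConf update the same configuration entry.
ctx.sql("SET spark.sql.shuffle.partitions=5")
assert(ctx.conf.numShufflePartitions == 5)

ctx.setConf("spark.sql.shuffle.partitions", "200")
assert(ctx.getConf("spark.sql.shuffle.partitions") == "200")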
val bypassMergeThreshold = - sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") + ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 358d8cf06e463..8ec3985e00360 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.debug -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ -class DebuggingSuite extends FunSuite { +class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 2aad01ded1acf..5290c28cfca02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.joins -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Projection, Row} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends FunSuite { +class HashedRelationSuite extends SparkFunSuite { // Key is simply the record itself private val keyProjection = new Projection { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 347f28351fd72..49d348c3ed21b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,14 +21,13 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} -import org.apache.spark.sql.test._ -import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException -import org.scalatest.{FunSuite, BeforeAndAfter} -import TestSQLContext._ -import TestSQLContext.implicits._ +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ -class JDBCSuite extends FunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -36,12 +35,16 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) val testH2Dialect = new JdbcDialect { - def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) } + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + 
import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test @@ -67,7 +70,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - + sql( s""" |CREATE TEMPORARY TABLE fetchtwo @@ -75,7 +78,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', | fetchSize '2') """.stripMargin.replaceAll("\n", " ")) - + sql( s""" |CREATE TEMPORARY TABLE parts @@ -208,7 +211,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(ids(1) === 2) assert(ids(2) === 3) } - + test("SELECT second field when fetchSize is two") { val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) @@ -252,26 +255,26 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.read.jdbc( + assert(ctx.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(TestSQLContext.read.jdbc( + assert(ctx.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -327,9 +330,9 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("test DATE types") { - val rows = TestSQLContext.read.jdbc( + val rows = ctx.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -337,9 +340,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("test DATE types in cache") { - val rows = - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -347,7 +349,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("test types for null value") { - val rows = TestSQLContext.read.jdbc( + val rows = ctx.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -394,10 
+396,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) - assert(df.schema.filter( - _.dataType != org.apache.spark.sql.types.StringType - ).isEmpty) + val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) assert(rows(0).get(1).isInstanceOf[String]) @@ -410,6 +410,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(JdbcDialects.get("test.invalid") == NoopDialect) } + test("quote column names by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + + val columns = Seq("abc", "key") + val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) + val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) + assert(MySQLColumns === Seq("`abc`", "`key`")) + assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + } + test("Dialect unregister") { JdbcDialects.registerDialect(testH2Dialect) JdbcDialects.unregisterDialect(testH2Dialect) @@ -418,7 +429,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("Aggregated dialects") { val agg = new AggregatedDialect(List(new JdbcDialect { - def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { @@ -429,8 +440,8 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) - assert(agg.getCatalystType(0,"",1,null) == Some(LongType)) - assert(agg.getCatalystType(1,"",1,null) == Some(StringType)) + assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 2e4c12f9da80c..d949ef42267ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} -import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ -class JDBCWriteSuite extends FunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -35,12 +35,16 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { properties.setProperty("user", "testUser") properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema 
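The new quoting test above leans on JdbcDialects resolving a dialect from the JDBC URL. The same lookup outside the test harness:

import org.apache.spark.sql.jdbc.JdbcDialects

val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")

// Each dialect escapes reserved words with its own quoting style.
mysql.quoteIdentifier("key")      // `key`
postgres.quoteIdentifier("key")   // "key"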
test").executeUpdate() - + conn1 = DriverManager.getConnection(url1, properties) conn1.prepareStatement("create schema test").executeUpdate() conn1.prepareStatement("drop table if exists test.people").executeUpdate() @@ -52,20 +56,20 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { conn1.prepareStatement( "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - - TestSQLContext.sql( + + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - - TestSQLContext.sql( + + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll("\n", " ")) } after { @@ -73,66 +77,64 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { conn1.close() } - val sc = TestSQLContext.sparkContext + private lazy val sc = ctx.sparkContext - val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - val arr1x2 = Array[Row](Row.apply("fred", 3)) - val schema2 = StructType( + private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) + private lazy val schema2 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: Nil) - val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) - val schema3 = StructType( + private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + private lazy val schema3 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 == - TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", 
properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 == - TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -141,15 +143,15 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) - } + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 7e6eeba17752a..d889c7be17ce7 100644 --- 
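The JDBCWriteSuite changes above exercise the DataFrameWriter/DataFrameReader JDBC round trip. A condensed sketch of that flow; the in-memory H2 URL, table name, and `ctx` are illustrative:

import java.util.Properties
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.types._

val url = "jdbc:h2:mem:sketchdb"
val props = new Properties()

val schema = StructType(StructField("name", StringType) :: StructField("id", IntegerType) :: Nil)
val df = ctx.createDataFrame(
  ctx.sparkContext.parallelize(Seq(Row("dave", 42), Row("mary", 222))), schema)

df.write.jdbc(url, "TEST.PEOPLE", props)                        // CREATE table and insert
df.write.mode(SaveMode.Append).jdbc(url, "TEST.PEOPLE", props)  // append the same rows again

ctx.read.jdbc(url, "TEST.PEOPLE", props).count()                // 4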
a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,21 +23,19 @@ import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.util.Utils -class JsonSuite extends QueryTest { - import org.apache.spark.sql.json.TestJsonData._ +class JsonSuite extends QueryTest with TestJsonData { - TestJsonData + protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.sql + import ctx.implicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -214,7 +212,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonDF = read.json(jsonNullStruct) + val jsonDF = ctx.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -233,7 +231,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonDF = read.json(primitiveFieldAndType) + val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -261,7 +259,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonDF = read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -360,7 +358,7 @@ class JsonSuite extends QueryTest { } test("GetField operation on complex data type") { - val jsonDF = read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -376,7 +374,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonDF = read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -450,7 +448,7 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. 
@@ -503,7 +501,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonDF = read.json(complexFieldValueTypeConflict) + val jsonDF = ctx.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -522,12 +520,12 @@ class JsonSuite extends QueryTest { Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: Row(null, """{"field":false}""", null, null, "{}") :: Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: - Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil + Row(Seq(7), "{}", """["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil ) } test("Type conflict in array elements") { - val jsonDF = read.json(arrayElementTypeConflict) + val jsonDF = ctx.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -555,7 +553,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonDF = read.json(missingFields) + val jsonDF = ctx.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -574,8 +572,9 @@ class JsonSuite extends QueryTest { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = read.option("samplingRatio", "0.49").json(path) + ctx.sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -590,7 +589,7 @@ class JsonSuite extends QueryTest { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) @@ -602,7 +601,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = read.json(path) + val jsonDF = ctx.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -671,7 +670,7 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = read.schema(schema).json(path) + val jsonDF1 = ctx.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -688,7 +687,7 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonDF2 = read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -709,7 +708,7 @@ class JsonSuite extends QueryTest { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -737,7 +736,7 @@ class JsonSuite extends QueryTest { val 
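The schema-handling tests above use two DataFrameReader features worth noting: sampling-based inference and a user-supplied schema. A sketch, with the path and `ctx` as placeholders:

import org.apache.spark.sql.types._

// Infer the schema from roughly half of the input records.
val sampled = ctx.read.option("samplingRatio", "0.49").json("/tmp/events.json")

// Or skip inference entirely by declaring the schema up front.
val schema = StructType(StructField("a", LongType, true) :: Nil)
val typed = ctx.read.schema(schema).json("/tmp/events.json")
assert(typed.schema == schema)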
schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -763,7 +762,7 @@ class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -781,7 +780,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonDF = read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -804,7 +803,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = read.json(jsonArray) + val jsonDF = ctx.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -822,10 +821,10 @@ class JsonSuite extends QueryTest { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = read.json(corruptRecords) + val jsonDF = ctx.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -875,11 +874,11 @@ class JsonSuite extends QueryTest { Row("]") :: Nil ) - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } test("SPARK-4068: nulls in arrays") { - val jsonDF = read.json(nullsInArrays) + val jsonDF = ctx.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -925,7 +924,7 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = createDataFrame(rowRDD1, schema1) + val df1 = ctx.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -948,7 +947,7 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = createDataFrame(rowRDD2, schema2) + val df3 = ctx.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -956,8 +955,8 @@ class JsonSuite extends QueryTest { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = read.json(primitiveFieldAndType) - val primTable = read.json(jsonDF.toJSON) + val jsonDF = ctx.read.json(primitiveFieldAndType) + val primTable = ctx.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -969,8 +968,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonDF = read.json(complexFieldAndType1) - val compTable = read.json(complexJsonDF.toJSON) + val complexJsonDF = ctx.read.json(complexFieldAndType1) + val compTable = ctx.read.json(complexJsonDF.toJSON) 
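Note: the JsonSuite and TestJsonData hunks in this area all apply one pattern: test fixtures and call sites go through a single explicit SQLContext handle (`ctx`) instead of importing members from the TestSQLContext singleton, so the same suite can be pointed at a different context. A minimal sketch of that shape, not part of the patch (ExampleJsonData, ExampleJsonSuite, sampleRecords and roundTrip are made-up names):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext

    // Fixtures become defs built against whatever context the suite supplies,
    // instead of vals baked against the TestSQLContext singleton.
    trait ExampleJsonData {                            // hypothetical stand-in for TestJsonData
      protected def ctx: SQLContext

      def sampleRecords: RDD[String] =
        ctx.sparkContext.parallelize("""{"a":1}""" :: """{"a":2}""" :: Nil)
    }

    class ExampleJsonSuite extends ExampleJsonData {   // hypothetical suite
      // The concrete suite picks the context once; every call site below is
      // qualified (ctx.read, ctx.setConf, ...) rather than wildcard-imported.
      protected lazy val ctx: SQLContext = org.apache.spark.sql.test.TestSQLContext

      def roundTrip(): Unit = {
        val df = ctx.read.json(sampleRecords)          // previously just read.json(...)
        assert(df.count() == 2)
      }
    }

Turning the fixtures from vals into defs also means no RDD is materialized until a concrete ctx exists.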
compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1074,29 +1073,29 @@ class JsonSuite extends QueryTest { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try{ for (useStreaming <- List("true", "false")) { - setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) val temp = Utils.createTempDir().getPath - val df = read.schema(schemaWithSimpleMap).json(mapType1) + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) df.write.mode("overwrite").parquet(temp) // order of MapType is not defined - assert(read.parquet(temp).count() == 5) + assert(ctx.read.parquet(temp).count() == 5) - val df2 = read.json(corruptRecords) + val df2 = ctx.read.json(corruptRecords) df2.write.mode("overwrite").parquet(temp) - checkAnswer(read.parquet(temp), df2.collect()) + checkAnswer(ctx.read.parquet(temp), df2.collect()) } } finally { - setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 47a97a49daabb..b6a6a8dc6a63c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext -object TestJsonData { +trait TestJsonData { - val primitiveFieldAndType = - TestSQLContext.sparkContext.parallelize( + protected def ctx: SQLContext + + def primitiveFieldAndType: RDD[String] = + ctx.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -32,8 +35,8 @@ object TestJsonData { "null":null }""" :: Nil) - val primitiveFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def primitiveFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -43,15 +46,15 @@ object TestJsonData { """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) - val jsonNullStruct = - TestSQLContext.sparkContext.parallelize( + def jsonNullStruct: RDD[String] = + ctx.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: 
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) - val complexFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def complexFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -61,23 +64,23 @@ object TestJsonData { """{"num_struct":{}, "str_array":["str1", "str2", 33], "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) - val arrayElementTypeConflict = - TestSQLContext.sparkContext.parallelize( + def arrayElementTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) - val missingFields = - TestSQLContext.sparkContext.parallelize( + def missingFields: RDD[String] = + ctx.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) - val complexFieldAndType1 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType1: RDD[String] = + ctx.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -92,8 +95,8 @@ object TestJsonData { "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) - val complexFieldAndType2 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType2: RDD[String] = + ctx.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -146,16 +149,16 @@ object TestJsonData { ]] }""" :: Nil) - val mapType1 = - TestSQLContext.sparkContext.parallelize( + def mapType1: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: """{"map": {"c": 1, "d": 4}}""" :: """{"map": {"e": null}}""" :: Nil) - val mapType2 = - TestSQLContext.sparkContext.parallelize( + def mapType2: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -163,22 +166,22 @@ object TestJsonData { """{"map": {"e": null}}""" :: """{"map": {"f": {"field1": null}}}""" :: Nil) - val nullsInArrays = - TestSQLContext.sparkContext.parallelize( + def nullsInArrays: RDD[String] = + ctx.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) - val jsonArray = - TestSQLContext.sparkContext.parallelize( + def jsonArray: RDD[String] = + ctx.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) - val corruptRecords = - TestSQLContext.sparkContext.parallelize( + def corruptRecords: RDD[String] = + ctx.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -186,6 +189,5 @@ object TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) - val empty = - TestSQLContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = 
ctx.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index bdc2ebabc5e9a..17f5f9a491e6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll -import parquet.filter2.predicate.Operators._ -import parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.Operators._ +import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} @@ -42,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * data type is nullable. */ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext private def checkFilterPredicate( df: DataFrame, @@ -312,7 +311,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -341,7 +340,7 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index dd48bb350f26d..46b25859d9a68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -23,22 +23,21 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.scalatest.BeforeAndAfterAll -import parquet.example.data.simple.SimpleGroup 
-import parquet.example.data.{Group, GroupWriter} -import parquet.hadoop.api.WriteSupport -import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} -import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} -import parquet.io.api.RecordConsumer -import parquet.schema.{MessageType, MessageTypeParser} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} @@ -66,9 +65,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS * A test suite that tests basic Parquet I/O. */ class ParquetIOSuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - import sqlContext.implicits.localSeqToDataFrameHolder + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. @@ -104,7 +102,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext + sqlContext.sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) .toDF() @@ -115,7 +113,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) - checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -123,7 +121,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).collect() + sqlContext.read.parquet(dir.getCanonicalPath).collect() } } @@ -131,14 +129,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).collect() + sqlContext.read.parquet(dir.getCanonicalPath).collect() } } } test("date type") { def makeDateRDD(): DataFrame = - sparkContext + sqlContext.sparkContext .parallelize(0 to 1000) .map(i => Tuple1(DateUtils.toJavaDate(i))) .toDF() @@ -147,7 +145,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDateRDD() data.write.parquet(dir.getCanonicalPath) - checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -200,7 +198,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNulls :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) } } @@ -213,7 +211,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNones :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) 
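Note: the Parquet suites drop the wildcard imports from TestSQLContext and instead hold a `lazy val sqlContext`, pull implicits from it, and qualify `read`, `sparkContext` and `conf` explicitly. Roughly the write/read-back shape under those conventions, as a sketch only (ParquetRoundTripSketch and writeAndReadBack are illustrative names):

    import org.apache.spark.sql.SQLContext

    object ParquetRoundTripSketch {                    // illustrative only
      lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
      import sqlContext.implicits._                    // supplies toDF() on local Seqs and RDDs of Products

      def writeAndReadBack(path: String): Unit = {
        // Everything is reached through sqlContext; nothing comes from a wildcard import.
        val df = sqlContext.sparkContext
          .parallelize(1 to 10)
          .map(i => Tuple1(i.toString))
          .toDF("value")
        df.write.parquet(path)
        assert(sqlContext.read.parquet(path).count() == 10)
      }
    }

Making sqlContext (and originalConf in the concrete suites) lazy presumably defers touching TestSQLContext until a test actually runs rather than at class-construction time.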
assert(rows.head === Row(Seq.fill(3)(null): _*)) } } @@ -236,7 +234,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { withParquetFile(data) { path => - assertResult(conf.parquetCompressionCodec.toUpperCase) { + assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { compressionCodecFor(path) } } @@ -244,7 +242,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) + checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) @@ -283,7 +281,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(read.parquet(path.toString), (0 until 10).map { i => + checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i => Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } @@ -312,7 +310,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) - checkAnswer(read.parquet(file), newData.map(Row.fromTuple)) + checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple)) } } @@ -321,7 +319,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) - checkAnswer(read.parquet(file), data.map(Row.fromTuple)) + checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple)) } } @@ -341,7 +339,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) - checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple)) + checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple)) } } @@ -369,11 +367,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sparkContext.hadoopConfiguration, + sqlContext.sparkContext.hadoopConfiguration, path, new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) - assertResult(read.parquet(path.toString).schema) { + assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( StructField("a", BooleanType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: @@ -383,6 +381,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("SPARK-6352 DirectParquetOutputCommitter") { + val clonedConf = new Configuration(configuration) + // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. 
try { @@ -397,16 +397,48 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val fs = path.getFileSystem(configuration) assert(!fs.exists(path)) } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } - finally { - configuration.set("spark.sql.parquet.output.committer.class", - "parquet.hadoop.ParquetOutputCommitter") + } + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + withTempPath { dir => + val clonedConf = new Configuration(configuration) + + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS, classOf[ParquetOutputCommitter].getCanonicalName) + + configuration.set( + "spark.sql.parquet.output.committer.class", + classOf[BogusParquetOutputCommitter].getCanonicalName) + + try { + val message = intercept[SparkException] { + sqlContext.range(0, 1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(message === "Intentional exception for testing purposes") + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } } } } +class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} + class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -430,7 +462,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA } class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 90d4528efca48..3240079483545 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -14,20 +14,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.parquet import java.io.File +import java.math.BigInteger +import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.sources.PartitioningUtils._ import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext} // The data where the partitioning key exists only in the directory structure. 
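Note: both the SPARK-6352 fix and the new SPARK-8121 test in ParquetIOSuite above snapshot the Hadoop Configuration into `clonedConf` and rebuild it in `finally`, because Hadoop 1 has no `Configuration.unset`. The idiom, pulled out as a sketch (ConfSnapshotSketch and withRestoredConf are hypothetical helpers, not in the patch):

    import scala.collection.JavaConversions._          // Configuration iterates as java Map.Entry pairs
    import org.apache.hadoop.conf.Configuration

    object ConfSnapshotSketch {                        // hypothetical helper
      // Run `body` with `conf` free to be mutated, then restore the original
      // entries; clear() + re-set stands in for unset(), which Hadoop 1 lacks.
      def withRestoredConf(conf: Configuration)(body: => Unit): Unit = {
        val snapshot = new Configuration(conf)         // copy constructor takes a full snapshot
        try body finally {
          conf.clear()
          snapshot.foreach(entry => conf.set(entry.getKey, entry.getValue))
        }
      }
    }

The tests inline this rather than extracting a helper, but the restore step is the part that matters: clear() wipes every key, then the snapshot's entries are replayed.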
case class ParquetData(intField: Int, stringField: String) @@ -36,33 +39,33 @@ case class ParquetData(intField: Int, stringField: String) case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - override val sqlContext: SQLContext = TestSQLContext - import sqlContext._ + override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ + import sqlContext.sql val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" test("column type inference") { def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) } check("10", Literal.create(10, IntegerType)) check("1000000000000000", Literal.create(1000000000000000L, LongType)) - check("1.5", Literal.create(1.5, FloatType)) + check("1.5", Literal.create(1.5, DoubleType)) check("hello", Literal.create("hello", StringType)) check(defaultPartitionName, Literal.create(null, NullType)) } test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - assert(expected === parsePartition(new Path(path), defaultPartitionName)) + assert(expected === parsePartition(new Path(path), defaultPartitionName, true)) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName).get + parsePartition(new Path(path), defaultPartitionName, true).get }.getMessage assert(message.contains(expected)) @@ -80,13 +83,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { ArrayBuffer( Literal.create(10, IntegerType), Literal.create("hello", StringType), - Literal.create(1.5, FloatType))) + Literal.create(1.5, DoubleType))) }) check("file://path/a=10/b_hello/c=1.5", Some { PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal.create(1.5, FloatType))) + ArrayBuffer(Literal.create(1.5, DoubleType))) }) check("file:///", None) @@ -102,7 +105,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("parse partitions") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) } check(Seq( @@ -118,7 +121,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "hdfs://host:9000/path/a=10.5/b=hello"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), @@ -137,7 +140,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), @@ -159,7 +162,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), @@ -171,6 +174,77 @@ 
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { PartitionSpec.emptySpec) } + test("parse partitions with type inference disabled") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq(Partition(Row("10", "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row("10.5", "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row("10.5", "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(Row("10", null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(Row("10.5", null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) + } + test("read partitioned table - normal case") { withTempDir { base => for { @@ -187,8 +261,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { // Introduce _temporary dir to the base dir the robustness of the schema discovery process. 
new File(base.getCanonicalPath, "_temporary").mkdir() - println("load the partitioned table") - read.parquet(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -235,7 +308,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.parquet(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -283,7 +356,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -323,7 +396,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -355,7 +428,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -368,7 +441,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("SPARK-7749 Non-partitioned table should have empty partition spec") { withTempPath { dir => (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) - val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution + val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation(relation: ParquetRelation2) => assert(relation.partitionSpec === PartitionSpec.emptySpec) @@ -377,4 +450,73 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } } } + + test("SPARK-7847: Dynamic partition directory path escaping and unescaping") { + withTempPath { dir => + val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") + df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) + } + } + + test("Various partition value types") { + val row = + Row( + 100.toByte, + 40000.toShort, + Int.MaxValue, + Long.MaxValue, + 1.5.toFloat, + 4.5, + new java.math.BigDecimal(new BigInteger("212500"), 5), + new java.math.BigDecimal(2.125), + java.sql.Date.valueOf("2015-05-23"), + new Timestamp(0), + "This is a string, /[]?=:", + "This is not a partition column") + + // BooleanType is not supported yet + val partitionColumnTypes = + Seq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(10, 5), + DecimalType.Unlimited, + DateType, + TimestampType, + StringType) + + val partitionColumns = partitionColumnTypes.zipWithIndex.map { + case 
(t, index) => StructField(s"p_$index", t) + } + + val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + + withTempPath { dir => + df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) + } + } + + test("SPARK-8037: Ignores files whose name starts with dot") { + withTempPath { dir => + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(dir.getCanonicalPath) + + Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b98ba09ccfc2d..de0107a361815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,16 +19,17 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ /** * A test suite that tests various Parquet queries. */ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + import sqlContext.sql test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -39,22 +40,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), data.map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -111,10 +112,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { List(Row("same", "run_5", 100))) } } + + test("SPARK-6917 DecimalType should work with non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + 
StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -126,7 +139,7 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd } class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index c964b6d984557..171a656f0e01e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.FunSuite -import parquet.schema.MessageTypeParser +import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends FunSuite with ParquetTest { - val sqlContext = TestSQLContext +class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 516ba373f41d2..eb15a1609f1d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -33,8 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. 
*/ private[sql] trait ParquetTest extends SQLTestUtils { - import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} - import sqlContext.sparkContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -44,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath) + sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -75,7 +73,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index d2d1011b8e917..a71088430bfd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,18 +26,20 @@ import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var path: File = null override def beforeAll(): Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - read.json(rdd).registerTempTable("jt") + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - dropTempTable("jt") + caseInsensitiveContext.dropTempTable("jt") } after { @@ -59,7 +61,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { @@ -129,7 +131,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT * FROM jsonTable"), sql("SELECT a * 4 FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") // Explicitly delete the data. 
if (path.exists()) Utils.deleteRecursively(path) @@ -147,7 +149,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT * FROM jsonTable"), sql("SELECT b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index f5106f67a08df..51d22b6a1378a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -43,7 +43,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), StructField("decimalType", DecimalType.Unlimited, nullable = false), - StructField("fixedDecimalType", DecimalType(5,1), nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), StructField("smallIntType", ShortType, nullable = false), @@ -51,8 +51,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("mapType", MapType(StringType, StringType)), StructField("arrayType", ArrayType(StringType)), StructField("structType", - StructType(StructField("f1",StringType) :: - (StructField("f2",IntegerType)) :: Nil + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil ) ) )) @@ -64,19 +63,18 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } class DDLTestSuite extends DataSourceTest { - import caseInsensitiveContext._ before { - sql( - """ - |CREATE TEMPORARY TABLE ddlPeople - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) + caseInsensitiveContext.sql( + """ + |CREATE TEMPORARY TABLE ddlPeople + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) } sqlTest( @@ -101,7 +99,8 @@ class DDLTestSuite extends DataSourceTest { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = sql("describe ddlPeople").queryExecution.executedPlan.output + val attributes = caseInsensitiveContext.sql("describe ddlPeople") + .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 24ed665c67d2e..3f77960d09246 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,14 +17,18 @@ package org.apache.spark.sql.sources +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfter + abstract class DataSourceTest extends QueryTest with BeforeAndAfter { // We want to test some edge cases. 
- implicit val caseInsensitiveContext = new SQLContext(TestSQLContext.sparkContext) + protected implicit lazy val caseInsensitiveContext = { + val ctx = new SQLContext(TestSQLContext.sparkContext) + ctx.setConf(SQLConf.CASE_SENSITIVE, "false") + ctx + } - caseInsensitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index cce747e7dbf64..81b3a0f0c5b3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -97,7 +97,7 @@ object FiltersPushed { class FilteredScanSuite extends DataSourceTest { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql before { sql( @@ -154,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2))) + Seq(1, 3, 5).map(i => Row(i, i * 2))) sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE A = 1", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 6f375ef36237d..0b7c46c482c88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -26,14 +26,16 @@ import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var path: File = null override def beforeAll: Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - read.json(rdd).registerTempTable("jt") + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) @@ -45,8 +47,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll: Unit = { - dropTempTable("jsonTable") - dropTempTable("jt") + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") Utils.deleteRecursively(path) } @@ -109,7 +111,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) - read.json(rdd1).registerTempTable("jt1") + caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -121,7 +123,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to more part files. 
val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) - read.json(rdd2).registerTempTable("jt2") + caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -140,8 +142,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { (1 to 10).map(i => Row(i * 10, s"str$i")) ) - dropTempTable("jt1") - dropTempTable("jt2") + caseInsensitiveContext.dropTempTable("jt1") + caseInsensitiveContext.dropTempTable("jt2") } test("INSERT INTO not supported for JSONRelation for now") { @@ -154,13 +156,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("save directly to the path of a JSON table") { - table("jt").selectExpr("a * 5 as a", "b").write.mode(SaveMode.Overwrite).json(path.toString) + caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") + .write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 5, s"str$i")) ) - table("jt").write.mode(SaveMode.Overwrite).json(path.toString) + caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) @@ -181,7 +184,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { test("Caching") { // Cached Query Execution - cacheTable("jsonTable") + caseInsensitiveContext.cacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable")) checkAnswer( sql("SELECT * FROM jsonTable"), @@ -220,7 +223,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2, b FROM jt").collect()) // Verify uncaching - uncacheTable("jsonTable") + caseInsensitiveContext.uncacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable"), 0) } @@ -251,6 +254,6 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { "It is not allowed to insert into a table that is not an InsertableRelation." 
) - dropTempTable("oneToTen") + caseInsensitiveContext.dropTempTable("oneToTen") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index c2bc52e2120c1..257526feab945 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -52,10 +52,9 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } class PrunedScanSuite extends DataSourceTest { - import caseInsensitiveContext._ before { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -115,7 +114,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = sql(sqlString).queryExecution + val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 8331a14c9295c..296b0d6f74a0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.sources -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ResolvedDataSourceSuite extends FunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite { test("builtin sources") { assert(ResolvedDataSource.lookupDataSource("jdbc") === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 274c652dd14d6..b032515a9d28c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -27,7 +27,9 @@ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var originalDefaultSource: String = null @@ -36,60 +38,63 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { var df: DataFrame = null override def beforeAll(): Unit = { - originalDefaultSource = conf.defaultDataSourceName + originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = read.json(rdd) + df = caseInsensitiveContext.read.json(rdd) df.registerTempTable("jsonTable") } override def afterAll(): Unit = { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } after { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } def checkLoad(): Unit = { - 
conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(read.load(path.toString), df.collect()) + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + checkAnswer(caseInsensitiveContext.read.load(path.toString), df.collect()) // Test if we can pick up the data source name passed in load. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(read.format("json").load(path.toString), df.collect()) - checkAnswer(read.format("json").load(path.toString), df.collect()) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( - read.format("json").schema(schema).load(path.toString), + caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), sql("SELECT b FROM jsonTable").collect()) } test("save with path and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") df.write.save(path.toString) checkLoad() } test("save with string mode and path, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") path.createNewFile() df.write.mode("overwrite").save(path.toString) checkLoad() } test("save with path and datasource, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") df.write.json(path.toString) checkLoad() } test("save with data source and options, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") df.write.mode(SaveMode.ErrorIfExists).json(path.toString) checkLoad() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 77af04a491742..5d4ecd810862c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -88,9 +88,9 @@ case class AllDataTypesScan( } class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql - var tableWithSchemaExpected = (1 to 10).map { i => + private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", s"str_$i", @@ -215,7 +215,7 @@ class TableScanSuite extends DataSourceTest { Nil ) - assert(expectedSchema == table("tableWithSchema").schema) + assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema) checkAnswer( sql( @@ -270,7 +270,7 @@ class TableScanSuite extends DataSourceTest { test("Caching") { // Cached Query Execution - cacheTable("oneToTen") + caseInsensitiveContext.cacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen")) checkAnswer( sql("SELECT * FROM oneToTen"), @@ -297,7 +297,7 @@ class TableScanSuite extends DataSourceTest { (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching - 
uncacheTable("oneToTen") + caseInsensitiveContext.uncacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen"), 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ca66cdc48272d..ac4a00a6f3dac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -25,11 +25,9 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { - val sqlContext: SQLContext + def sqlContext: SQLContext - import sqlContext.{conf, sparkContext} - - protected def configuration = sparkContext.hadoopConfiguration + protected def configuration = sqlContext.sparkContext.hadoopConfiguration /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -39,12 +37,12 @@ trait SQLTestUtils { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) - (keys, values).zipped.foreach(conf.setConf) + val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConf) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConf(key, value) - case (key, None) => conf.unsetConf(key) + case (key, Some(value)) => sqlContext.conf.setConf(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) } } } @@ -75,14 +73,18 @@ trait SQLTestUtils { /** * Drops temporary table `tableName` after calling `f`. */ - protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.dropTempTable(tableName) + protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(sqlContext.dropTempTable) } /** * Drops table `tableName` after calling `f`. 
*/ - protected def withTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.sql(s"DROP TABLE IF EXISTS $tableName") + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + sqlContext.sql(s"DROP TABLE IF EXISTS $name") + } + } } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 437f697d25bf3..73e6ccdb1eaf8 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ spark-hive_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3458b04bfba0f..c9da25253e13f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.SQLConf -import org.apache.spark.{SparkContext, SparkConf, Logging} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener} import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkContext} -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer /** * The main entry point for the Spark SQL port of HiveServer2. 
Starts up a `SparkSQLContext` and a @@ -51,6 +52,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) + sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala similarity index 52% rename from sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index b9d4f1c58c982..e071103df925c 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -17,78 +17,55 @@ package org.apache.spark.sql.hive.thriftserver +import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager +import java.util.concurrent.RejectedExecutionException +import java.util.{Map => JMap, UUID} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.util.control.NonFatal +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} +import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.Logging import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} -/** - * A compatibility layer for interacting with Hive version 0.13.1. 
- */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.13.1" - - def setServerUserName( - sparkServiceUGI: UserGroupInformation, - sparkCliService:SparkSQLCLIService) = { - setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JList[_]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) - hiveResponse = null - true - } - } -} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], - runInBackground: Boolean = true)( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution - extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { + runInBackground: Boolean = true) + (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) + extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground) + with Logging { private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ + private var statementId: String = _ def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") + hiveContext.sparkContext.clearJobGroup() + logDebug(s"CLOSING $statementId") + cleanup(OperationState.CLOSED) } - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { dataTypes(ordinal) match { case StringType => to += from.getString(ordinal) @@ -149,10 +126,10 @@ private[hive] class SparkExecuteStatementOperation( } def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { + if (result == null || result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") val schema = result.queryExecution.analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } @@ -160,9 +137,73 @@ private[hive] class SparkExecuteStatementOperation( } } - def run(): Unit = { - val statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement'") + override def run(): Unit = { + setState(OperationState.PENDING) + setHasResultSet(true) // avoid no resultset for async run + + if (!runInBackground) { + runInternal() + } else { + val parentSessionState = SessionState.get() + val hiveConf = getConfigForOperation() + val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + val sessionHive = getCurrentHive() + val currentSqlSession = hiveContext.currentSession + + // Runnable impl to call runInternal asynchronously, + // from a different thread + val backgroundOperation = new Runnable() { + + override def run(): Unit = { + val doAsAction = new PrivilegedExceptionAction[Object]() { + override def run(): Object = { + + // User information is part of the metastore client member in Hive + hiveContext.setSession(currentSqlSession) + Hive.set(sessionHive) + 
SessionState.setCurrentSessionState(parentSessionState) + try { + runInternal() + } catch { + case e: HiveSQLException => + setOperationException(e) + log.error("Error running hive query: ", e) + } + return null + } + } + + try { + ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) + } catch { + case e: Exception => + setOperationException(new HiveSQLException(e)) + logError("Error running hive query as user : " + + sparkServiceUGI.getShortUserName(), e) + } + } + } + try { + // This submit blocks if no background threads are available to run this operation + val backgroundHandle = + getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) + setBackgroundHandle(backgroundHandle) + } catch { + case rejected: RejectedExecutionException => + setState(OperationState.ERROR) + throw new HiveSQLException("The background threadpool cannot accept" + + " new task for execution, please retry the operation", rejected) + case NonFatal(e) => + logError(s"Error executing query in background", e) + setState(OperationState.ERROR) + throw e + } + } + } + + private def runInternal(): Unit = { + statementId = UUID.randomUUID().toString + logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) HiveThriftServer2.listener.onStatementStart( statementId, @@ -194,63 +235,82 @@ private[hive] class SparkExecuteStatementOperation( } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) } catch { + case e: HiveSQLException => + if (getStatus().getState() == OperationState.CANCELED) { + return + } else { + setState(OperationState.ERROR); + throw e + } // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. 
case e: Throwable => + val currentState = getStatus().getState() + logError(s"Error executing query, currentState $currentState, ", e) setState(OperationState.ERROR) HiveThriftServer2.listener.onStatementError( statementId, e.getMessage, e.getStackTraceString) - logError("Error executing query:", e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) HiveThriftServer2.listener.onStatementFinish(statementId) } -} - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) + override def cancel(): Unit = { + logInfo(s"Cancel '$statement' with $statementId") + if (statementId != null) { + hiveContext.sparkContext.cancelJobGroup(statementId) + } + cleanup(OperationState.CANCELED) } - override def openSession( - protocol: TProtocolVersion, - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) - val session = super.getSession(sessionHandle) - HiveThriftServer2.listener.onSessionCreated( - session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - sessionHandle + private def cleanup(state: OperationState) { + setState(state) + if (runInBackground) { + val backgroundHandle = getBackgroundHandle() + if (backgroundHandle != null) { + backgroundHandle.cancel(true) + } + } } - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle + /** + * If there are query specific settings to overlay, then create a copy of config + * There are two cases we need to clone the session config that's being passed to hive driver + * 1. Async query - + * If the client changes a config setting, that shouldn't reflect in the execution + * already underway + * 2. 
confOverlay - + * The query specific settings should only be applied to the query config and not session + * @return new configuration + * @throws HiveSQLException + */ + private def getConfigForOperation(): HiveConf = { + var sqlOperationConf = getParentSession().getHiveConf() + if (!getConfOverlay().isEmpty() || runInBackground) { + // clone the parent session config for this query + sqlOperationConf = new HiveConf(sqlOperationConf) + + // apply overlay query specific settings, if any + getConfOverlay().foreach { case (k, v) => + try { + sqlOperationConf.verifyAndSet(k, v) + } catch { + case e: IllegalArgumentException => + throw new HiveSQLException("Error applying statement specific settings", e) + } + } + } + return sqlOperationConf + } - hiveContext.detachSession() + private def getCurrentHive(): Hive = { + try { + return Hive.get() + } catch { + case e: HiveException => + throw new HiveSQLException("Failed to get current Hive object", e); + } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index deb1008c468bf..039cfa40d26b3 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,18 +32,18 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging -import org.apache.spark.sql.hive.{HiveContext, HiveShim} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') - private var transport:TSocket = _ + private var transport: TSocket = _ installSignalHandler() @@ -267,7 +267,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || @@ -276,13 +276,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { driver.init() val out = sessionState.out - val start:Long = System.currentTimeMillis() + val start: Long = System.currentTimeMillis() if (sessionState.getIsVerbose) { out.println(cmd) } val rc = driver.run(cmd) val end = System.currentTimeMillis() - val timeTaken:Double = (end - start) / 1000.0 + val timeTaken: Double = (end - start) / 1000.0 ret = rc.getResponseCode if (ret != 0) { @@ -310,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { res.clear() } } catch { - case e:IOException => + case e: IOException => console.printError( s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} 
|${org.apache.hadoop.util.StringUtils.stringifyException(e)} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 499e077d7294a..41f647d5f8c5a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,8 +21,6 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException -import scala.collection.JavaConversions._ - import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader @@ -34,7 +32,8 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -52,7 +51,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala similarity index 86% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 48ac9062af96a..77272aecf2835 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ +import java.util.{ArrayList => JArrayList, List => JList} import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +27,12 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -private[hive] abstract class AbstractSparkSQLDriver( - val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { +import scala.collection.JavaConversions._ + +private[hive] class SparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver + with Logging { private[hive] var tableSchema: Schema = _ private[hive] var hiveResponse: Seq[String] = _ @@ -71,6 +75,16 @@ private[hive] abstract class AbstractSparkSQLDriver( 0 } + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } + override def 
getSchema: Schema = tableSchema override def destroy() { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 7c0c505e2d61e..79eda1f5123bf 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils @@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - hiveContext.setConf("spark.sql.hive.version", HiveShim.version) + hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..2d5ee68002286 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.SessionHandle +import org.apache.hive.service.cli.session.SessionManager +import org.apache.hive.service.cli.thrift.TProtocolVersion + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } + + override def openSession( + protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + val sessionHandle = super.openSession( + protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle + } + + override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9c0bf02391e0e..c8031ed0f3437 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -44,9 +44,12 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( - hiveContext, sessionToActivePool) + val runInBackground = async && hiveContext.hiveThriftServerAsync + val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, + runInBackground)(hiveContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created Operation for $statement with session=$parentSession, " + + s"runInBackground=$runInBackground") operation } } diff --git 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 6a2be4a58e5cb..10c83d8b27a2a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,7 +47,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" ++ generateSessionStatsTable() ++ generateSQLStatsTable() - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -77,7 +77,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} @@ -85,7 +85,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {info.groupId} {formatDate(info.startTimestamp)} - {if(info.finishTimestamp > 0) formatDate(info.finishTimestamp)} + {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} {formatDurationOption(Some(info.totalTime))} {info.statement} {info.state} @@ -143,14 +143,14 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/ThriftServer/session?id=%s" + val sessionLink = "%s/sql/session?id=%s" .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) {session.userName} {session.ip} {session.sessionId} {formatDate(session.startTimestamp)} - {if(session.finishTimestamp > 0) formatDate(session.finishTimestamp)} + {if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)} {formatDurationOption(Some(session.totalTime))} {session.totalExecution.toString} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 33ba038ecce73..3b01afa603cea 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -55,7 +55,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Total run {sessionStat._2.totalExecution} SQL ++ generateSQLStatsTable(sessionStat._2.sessionId) - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ @@ -87,7 +87,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 343031f10c75c..94fd8a6bb60b9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,7 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. */ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "ThriftServer") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + + override val name = "SQL" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index b070fa8eaa469..13b0c5951dddc 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,12 +25,16 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils -class CliSuite extends FunSuite with BeforeAndAfter with Logging { +/** + * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary + * Hive metastore and warehouse. + */ +class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() @@ -58,13 +62,13 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --driver-class-path ${sys.props("java.class.path")} """.stripMargin.split("\\s+").toSeq ++ extraArgs } var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. 
+ val queryStream = new ByteArrayInputStream(queries.map(_ + "\n").mkString.getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object @@ -124,12 +128,12 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "Time taken: " + -> "OK" ) } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } test("Single command with --database") { @@ -151,4 +155,33 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { -> "hive_test" ) } + + test("Commands using SerDe provided in --jars") { + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + runCliWithin(3.minute, Seq("--jars", s"$jarFile"))( + """CREATE TABLE t1(key string, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; + """.stripMargin + -> "OK", + "CREATE TABLE sourceTable (key INT, val STRING);" + -> "OK", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" + -> "OK", + "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" + -> "Time taken:", + "SELECT count(key) FROM t1;" + -> "5", + "DROP TABLE t1;" + -> "OK", + "DROP TABLE sourceTable;" + -> "OK" + ) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1fadea97fd07f..178bd1f5cb164 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,14 +19,18 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.sql.{Date, DriverManager, Statement} +import java.nio.charset.StandardCharsets +import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise} +import scala.concurrent.{Await, Promise, future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper @@ -35,10 +39,10 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging -import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils object TestData { @@ -54,7 +58,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { - // 
Transport creation logics below mimics HiveConnection.createBinaryTransport + // Transport creation logic below mimics HiveConnection.createBinaryTransport val rawTransport = new TSocket("localhost", serverPort) val user = System.getProperty("user.name") val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) @@ -109,7 +113,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } @@ -335,6 +340,42 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } ) } + + test("test jdbc cancel") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + val largeJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(10)("join test_map").mkString(" ") + val f = future { Thread.sleep(100); statement.cancel(); } + val e = intercept[SQLException] { + statement.executeQuery(largeJoin) + } + assert(e.getMessage contains "cancelled") + Await.result(f, 3.minute) + + // cancel is a noop + statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") + val sf = future { Thread.sleep(100); statement.cancel(); } + val smallJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(4)("join test_map").mkString(" ") + val rs1 = statement.executeQuery(smallJoin) + Await.result(sf, 3.minute) + rs1.next() + assert(rs1.getInt(1) === math.pow(5, 5)) + rs1.close() + + val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") + rs2.next() + assert(rs2.getInt(1) === 5) + rs2.close() + } + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { @@ -363,7 +404,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } } @@ -391,10 +433,10 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { val statements = connections.map(_.createStatement()) try { - statements.zip(fs).map { case (s, f) => f(s) } + statements.zip(fs).foreach { case (s, f) => f(s) } } finally { - statements.map(_.close()) - connections.map(_.close()) + statements.foreach(_.close()) + connections.foreach(_.close()) } } @@ -403,7 +445,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } -abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { +abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging { def mode: ServerMode.Value private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") @@ -433,15 +475,33 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT } + val driverClassPath = { + // Writes a temporary log4j.properties and prepend it to driver classpath, so that it + // overrides all other potential log4j configurations contained in other dependency jar files. 
+ val tempLog4jConf = Utils.createTempDir().getCanonicalPath + + Files.write( + """log4j.rootCategory=INFO, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin, + new File(s"$tempLog4jConf/log4j.properties"), + UTF_8) + + tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + } + s"""$startScript | --master local - | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf $portConf=$port - | --driver-class-path ${sys.props("java.class.path")} + | --driver-class-path $driverClassPath + | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index a286dc5825f77..4c9fab7ef6136 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -73,7 +73,7 @@ class UISeleniumSuite } ignore("thrift server ui test") { - withJdbcStatement(statement =>{ + withJdbcStatement { statement => val baseURL = s"http://localhost:$uiPort" val queries = Seq( @@ -84,11 +84,11 @@ class UISeleniumSuite eventually(timeout(10 seconds), interval(50 milliseconds)) { go to baseURL - find(cssSelector("""ul li a[href*="ThriftServer"]""")) should not be None + find(cssSelector("""ul li a[href*="sql"]""")) should not be None } eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (baseURL + "/ThriftServer") + go to (baseURL + "/sql") find(id("sessionstat")) should not be None find(id("sqlstat")) should not be None @@ -97,6 +97,6 @@ class UISeleniumSuite findAll(cssSelector("""ul table tbody tr td""")).map(_.text).toList should contain (line) } } - }) + } } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala deleted file mode 100644 index b3a79ba1c7d6b..0000000000000 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} - -import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} -import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.types._ - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.12.0" - - def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { - val serverUserName = ShimLoader.getHadoopShims.getShortUserName(sparkServiceUGI) - setSuperField(sparkCliService, "serverUserName", serverUserName) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JArrayList[String]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.addAll(hiveResponse) - hiveResponse = null - true - } - } -} - -private[hive] class SparkExecuteStatementOperation( - parentSession: HiveSession, - statement: String, - confOverlay: JMap[String, String])( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - - private var result: DataFrame = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. 
- logDebug("CLOSING") - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - if (!iter.hasNext) { - new RowSet() - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) - - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = new Row() - var curCol = 0 - - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - addNullColumnValue(sparkRow, row, curCol) - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - rowSet += row - curRow += 1 - } - new RowSet(rowSet, 0) - } - } - - def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(from(ordinal).asInstanceOf[String]) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType() => - val hiveDecimal = from.getDecimal(ordinal) - to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case DateType => - to.addColumnValue(ColumnValue.dateValue(from(ordinal).asInstanceOf[Date])) - case TimestampType => - to.addColumnValue( - ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) - to.addColumnValue(ColumnValue.stringValue(hiveString)) - } - } - - def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(null) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(null)) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(null)) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(null)) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType() => - to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) - case LongType => - to.addColumnValue(ColumnValue.longValue(null)) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(null)) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(null)) - case DateType => - to.addColumnValue(ColumnValue.dateValue(null)) - case TimestampType => - to.addColumnValue(ColumnValue.timestampValue(null)) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - to.addColumnValue(ColumnValue.stringValue(null: String)) - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - val 
statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - HiveThriftServer2.listener.onStatementStart( - statementId, parentSession.getSessionHandle.getSessionId.toString, statement, statementId) - hiveContext.sparkContext.setJobGroup(statementId, statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - try { - result = hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => - sessionToActivePool(parentSession.getSessionHandle) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) - iter = { - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - result.rdd.toLocalIterator - } else { - result.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - setState(OperationState.ERROR) - HiveThriftServer2.listener.onStatementError( - statementId, e.getMessage, e.getStackTraceString) - logError("Error executing query:",e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - HiveThriftServer2.listener.onStatementFinish(statementId) - } -} - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def openSession( - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - username, passwd, sessionConf, withImpersonation, delegationToken) - HiveThriftServer2.listener.onSessionCreated("UNKNOWN", sessionHandle.getSessionId.toString) - sessionHandle - } - - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0b1917a392901..048f78b4daa8d 100644 --- 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,7 +23,6 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive /** @@ -254,7 +253,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // the answer is sensitive for jdk version "udf_java_method" - ) ++ HiveShim.compatibilityBlackList + ) /** * The set of tests that are believed to be working in catalyst. Tests not on whiteList or diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index e322340094e6f..a17546d706248 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} @@ -136,16 +143,6 @@ - - hive-0.12.0 - - - com.twitter - parquet-hive-bundle - 1.5.0 - - - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 3f20c6142e59a..7f8449cdc282d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -29,10 +29,10 @@ import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") + protected val JAR = Keyword("JAR") protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b64768ababef9..3b8cafb4a6c37 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.hive -import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.io.File +import java.net.{URL, URLClassLoader} import java.sql.Timestamp -import java.util.{ArrayList => JArrayList} -import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ @@ -30,24 +31,20 @@ import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution -import 
org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} -import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -146,6 +143,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { getConf("spark.sql.hive.metastore.barrierPrefixes", "") .split(",").filterNot(_ == "") + /* + * hive thrift server use background spark sql thread pool to execute sql queries + */ + protected[hive] def hiveThriftServerAsync: Boolean = + getConf("spark.sql.hive.thriftServer.async", "true").toBoolean + @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -188,13 +191,22 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") } - val jars = getClass.getClassLoader match { - case urlClassLoader: java.net.URLClassLoader => urlClassLoader.getURLs - case other => - throw new IllegalArgumentException( - "Unable to locate hive jars to connect to metastore " + - s"using classloader ${other.getClass.getName}. " + - "Please set spark.sql.hive.metastore.jars") + + // We recursively find all jars in the class loader chain, + // starting from the given classLoader. + def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { + case null => Array.empty[URL] + case urlClassLoader: URLClassLoader => + urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) + case other => allJars(other.getParent) + } + + val classLoader = Utils.getContextOrSparkClassLoader + val jars = allJars(classLoader) + if (jars.length == 0) { + throw new IllegalArgumentException( + "Unable to locate hive jars to connect to metastore. 
" + + "Please set spark.sql.hive.metastore.jars.") } logInfo( @@ -321,7 +333,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableParameters = relation.hiveQlTable.getParameters val oldTotalSize = - Option(tableParameters.get(HiveShim.getStatsSetupConstTotalSize)) + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) .map(_.toLong) .getOrElse(0L) val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) @@ -332,7 +344,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.client.alterTable( relation.table.copy( properties = relation.table.properties + - (HiveShim.getStatsSetupConstTotalSize -> newTotalSize.toString))) + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) } case otherRelation => throw new UnsupportedOperationException( @@ -344,9 +356,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def setConf(key: String, value: String): Unit = { super.setConf(key, value) - hiveconf.set(key, value) executionHive.runSqlHive(s"SET $key=$value") metadataHive.runSqlHive(s"SET $key=$value") + // If users put any Spark SQL setting in the spark conf (e.g. spark-defaults.conf), + // this setConf will be called in the constructor of the SQLContext. + // Also, calling hiveconf will create a default session containing a HiveConf, which + // will interfer with the creation of executionHive (which is a lazy val). So, + // we put hiveconf.set at the end of this method. + hiveconf.set(key, value) } /* A catalyst metadata catalog that points to the Hive Metastore. */ @@ -356,10 +373,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry { - override def conf: CatalystConf = currentSession().conf - } + override protected[sql] lazy val functionRegistry: FunctionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry /* An analyzer that uses the Hive metastore. */ @transient @@ -515,7 +530,7 @@ private[hive] object HiveContext { val propMap: HashMap[String, String] = HashMap() // We have to mask all properties in hive-site.xml that relates to metastore data source // as we used a local metastore here. 
- HiveConf.ConfVars.values().foreach { confvar => + HiveConf.ConfVars.values().foreach { confvar => if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { propMap.put(confvar.varname, confvar.defaultVal) } @@ -538,7 +553,7 @@ private[hive] object HiveContext { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -549,7 +564,7 @@ private[hive] object HiveContext { case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveShim.createDecimal(decimal).toString + HiveDecimal.create(decimal).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 0a694c70e4e5c..c466203cd0220 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -335,7 +336,7 @@ private[hive] trait HiveInspectors { val allRefs = si.getAllStructFieldRefs new GenericRow( allRefs.map(r => - unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray) } @@ -350,7 +351,7 @@ private[hive] trait HiveInspectors { new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) case _: JavaDateObjectInspector => (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) @@ -439,31 +440,31 @@ private[hive] trait HiveInspectors { case _ if a == null => null case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. 
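The hunk that follows inlines the former HiveShim `get*Writable` helpers while keeping the existing `preferWritable()` branching when converting Catalyst values for Hive object inspectors. A minimal sketch of that branching for a single type, assuming the inspector decides the representation (`WrapSketch` and `wrapInt` are hypothetical names, not part of the patch):

```scala
import org.apache.hadoop.io.IntWritable
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector

object WrapSketch {
  // Inspectors that prefer Writables get a Hadoop Writable; the rest get the boxed
  // Java value Hive expects. Nulls pass through unchanged, as in wrap() below.
  def wrapInt(value: Any, oi: IntObjectInspector): AnyRef =
    if (value == null) {
      null
    } else if (oi.preferWritable()) {
      new IntWritable(value.asInstanceOf[Int])
    } else {
      value.asInstanceOf[java.lang.Integer]
    }
}
```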
- case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) + case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() - case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) + case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) + case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a) + case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a) + case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a) + case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a) + case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a) + case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] case _: HiveDecimalObjectInspector if x.preferWritable() => - HiveShim.getDecimalWritable(a.asInstanceOf[Decimal]) + getDecimalWritable(a.asInstanceOf[Decimal]) case _: HiveDecimalObjectInspector => - HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) + HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) + case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) + case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) - case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) + case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] } case x: SettableStructObjectInspector => @@ -561,8 +562,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) } /** @@ -574,31 +575,31 @@ private[hive] trait HiveInspectors { */ def toInspector(expr: Expression): ObjectInspector = expr match { case Literal(value, StringType) => - 
HiveShim.getStringWritableConstantObjectInspector(value) + getStringWritableConstantObjectInspector(value) case Literal(value, IntegerType) => - HiveShim.getIntWritableConstantObjectInspector(value) + getIntWritableConstantObjectInspector(value) case Literal(value, DoubleType) => - HiveShim.getDoubleWritableConstantObjectInspector(value) + getDoubleWritableConstantObjectInspector(value) case Literal(value, BooleanType) => - HiveShim.getBooleanWritableConstantObjectInspector(value) + getBooleanWritableConstantObjectInspector(value) case Literal(value, LongType) => - HiveShim.getLongWritableConstantObjectInspector(value) + getLongWritableConstantObjectInspector(value) case Literal(value, FloatType) => - HiveShim.getFloatWritableConstantObjectInspector(value) + getFloatWritableConstantObjectInspector(value) case Literal(value, ShortType) => - HiveShim.getShortWritableConstantObjectInspector(value) + getShortWritableConstantObjectInspector(value) case Literal(value, ByteType) => - HiveShim.getByteWritableConstantObjectInspector(value) + getByteWritableConstantObjectInspector(value) case Literal(value, BinaryType) => - HiveShim.getBinaryWritableConstantObjectInspector(value) + getBinaryWritableConstantObjectInspector(value) case Literal(value, DateType) => - HiveShim.getDateWritableConstantObjectInspector(value) + getDateWritableConstantObjectInspector(value) case Literal(value, TimestampType) => - HiveShim.getTimestampWritableConstantObjectInspector(value) + getTimestampWritableConstantObjectInspector(value) case Literal(value, DecimalType()) => - HiveShim.getDecimalWritableConstantObjectInspector(value) + getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => - HiveShim.getPrimitiveNullWritableConstantObjectInspector + getPrimitiveNullWritableConstantObjectInspector case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) if (value == null) { @@ -658,8 +659,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) - case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) + case w: WritableHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -668,17 +669,143 @@ private[hive] trait HiveInspectors { case _: JavaVoidObjectInspector => NullType } + private def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + private def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) + + private def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.intTypeInfo, getIntWritable(value)) + + private def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.doubleTypeInfo, 
getDoubleWritable(value)) + + private def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) + + private def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.longTypeInfo, getLongWritable(value)) + + private def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) + + private def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) + + private def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) + + private def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) + + private def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) + + private def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) + + private def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) + + private def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.voidTypeInfo, null) + + private def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) + + private def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + private def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + } + + private def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + private def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + private def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + private def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + private def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + private def 
getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + private def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) + + private def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal)) + } + implicit class typeInfoConversions(dt: DataType) { import org.apache.hadoop.hive.serde2.typeinfo._ import TypeInfoFactory._ + private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo( + HiveShim.UNLIMITED_DECIMAL_PRECISION, + HiveShim.UNLIMITED_DECIMAL_SCALE) + } + def toTypeInfo: TypeInfo = dt match { case ArrayType(elemType, _) => getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name) :_*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) :_*)) + java.util.Arrays.asList(fields.map(_.name) : _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo @@ -690,7 +817,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case d: DecimalType => HiveShim.decimalTypeInfo(d) + case d: DecimalType => decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 425a4005aa2c3..5a4651a887b7c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} + import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ -import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} @@ -37,7 +39,6 @@ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} -import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -546,13 +547,17 @@ private[hive] class HiveMetastoreCatalog(val client: 
ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = ??? + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + throw new UnsupportedOperationException + } /** * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + throw new UnsupportedOperationException + } override def unregisterAllTables(): Unit = {} } @@ -592,7 +597,7 @@ private[hive] case class MetastoreRelation self: Product => - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => databaseName == relation.databaseName && tableName == relation.tableName && @@ -666,8 +671,8 @@ private[hive] case class MetastoreRelation @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) - val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) + val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) + val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. An @@ -693,11 +698,7 @@ private[hive] case class MetastoreRelation } } - val tableDesc = HiveShim.getTableDesc( - Class.forName( - hiveQlTable.getSerializationLib, - true, - Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], + val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to @@ -707,25 +708,25 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { + implicit class SchemaAttribute(f: HiveColumn) { def toAttribute: AttributeReference = AttributeReference( - f.getName, - HiveMetastoreTypes.toDataType(f.getType), + f.name, + HiveMetastoreTypes.toDataType(f.hiveType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) } - // Must be a stable value since new attributes are born here. - val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) + /** PartitionKey attributes */ + val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = table.schema.map(_.toAttribute) val output = attributes ++ partitionKeys /** An attribute map that can be used to lookup original attributes based on expression id. 
*/ - val attributeMap = AttributeMap(output.map(o => (o,o))) + val attributeMap = AttributeMap(output.map(o => (o, o))) /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) @@ -739,6 +740,11 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($HiveShim.UNLIMITED_DECIMAL_PRECISION,$HiveShim.UNLIMITED_DECIMAL_SCALE)" + } + def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => @@ -755,7 +761,7 @@ private[hive] object HiveMetastoreTypes { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case d: DecimalType => HiveShim.decimalMetastoreString(d) + case d: DecimalType => decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2cbb5ca4d2e0c..9544d12c9053c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive import java.sql.Date -import scala.collection.mutable.ArrayBuffer - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.ql.{ErrorMsg, Context} @@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ @@ -46,6 +45,7 @@ import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer /** * Used when we need to start parsing the AST before deciding that we are going to pass the command @@ -57,7 +57,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty } -case class CreateTableAsSelect( +private[hive] case class CreateTableAsSelect( tableDesc: HiveTable, child: LogicalPlan, allowExisting: Boolean) extends UnaryNode with Command { @@ -665,7 +665,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C HiveColumn(field.getName, field.getType, field.getComment) }) } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil)=> + case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => @@ -775,7 +775,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" case 
Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder + Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => @@ -1151,7 +1151,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Seq(false, false) => Inner }.toBuffer - val joinedTables = tables.reduceLeft(Join(_,_, Inner, None)) + val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) // Must be transform down. val joinedResult = joinedTables transform { @@ -1171,7 +1171,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // worth the number of hacks that will be required to implement it. Namely, we need to add // some sort of mapped star expansion that would expand all child output row to be similarly // named output expressions where some aggregate expression has been applied (i.e. First). - ??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + throw new UnsupportedOperationException case Token(allJoinTokens(joinToken), relation1 :: @@ -1560,6 +1561,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C """.stripMargin) } + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { case Token(windowName, Nil) :: Nil => // Refer to a window spec defined in the window clause. @@ -1613,11 +1618,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } else { val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token("preceding", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt) - case Token("following", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt) - case Token("current", Nil) => CurrentRow + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow case _ => throw new NotImplementedError( s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 0000000000000..d08c594151654 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +/* Implicit conversions */ +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable + +import org.apache.spark.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + + /* + * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null && ids.nonEmpty) { + ColumnProjectionUtils.appendReadColumns(conf, ids) + } + if (names != null && names.nonEmpty) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. 
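The `toCatalystDecimal` helper and the fixed 38/18 "unlimited" precision and scale defined above determine how Hive decimals enter Catalyst. A small usage sketch of that conversion, assuming Hive's `HiveDecimal` and Spark's `Decimal` are on the classpath (`DecimalRoundTrip` is an illustrative name, not part of the patch):

```scala
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.spark.sql.types.Decimal

object DecimalRoundTrip {
  def main(args: Array[String]): Unit = {
    // Round-trip a Hive decimal into Catalyst's Decimal with an explicit precision/scale,
    // the same shape of call toCatalystDecimal makes with the 38/18 defaults.
    val hive = HiveDecimal.create(new java.math.BigDecimal("123.45"))
    val catalyst = Decimal(BigDecimal(hive.bigDecimalValue), 38, 18)
    println(catalyst.toJavaBigDecimal) // scale padded out to 18 digits
  }
}
```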
+ * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + private var instance: AnyRef = null + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.read(functionInBytes, 0, functionInBytesLength) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. 
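The wrapper above works around the non-serializable `Path` member of Hive's `FileSinkDesc` (per the comment in the patch) by carrying only the directory `String` and rebuilding the `Path` on demand. A minimal sketch of that idea in isolation, assuming only Hadoop's `Path` (`SinkLocationRoundTrip` and `SinkLocation` are made-up stand-ins, not Spark or Hive classes):

```scala
import java.io._
import org.apache.hadoop.fs.Path

object SinkLocationRoundTrip {
  // Only the directory String is serialized; the Path is rebuilt lazily on first use.
  case class SinkLocation(dir: String) {
    @transient lazy val path: Path = new Path(dir)
  }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(SinkLocation("/tmp/hive-staging"))
    oos.close()

    val restored = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
      .readObject().asInstanceOf[SinkLocation]
    println(restored.path) // Path reconstructed from the String after deserialization
  }
}
```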
+ */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 0b6f7a334a715..334bfccc9d200 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,14 +25,13 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.SerializableWritable +import org.apache.spark.{Logging, SerializableWritable} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateUtils @@ -79,7 +78,7 @@ class HadoopTableReader( makeRDDForTable( hiveTable, Class.forName( - relation.tableDesc.getSerdeClassName, true, Utils.getSparkClassLoader) + relation.tableDesc.getSerdeClassName, true, Utils.getContextOrSparkClassLoader) .asInstanceOf[Class[Deserializer]], filterOpt = None) @@ -172,7 +171,7 @@ class HadoopTableReader( path.toString + tails } - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); var pathPatternStr = getPathPatternByPath(partNum, partPath) if (!pathPatternSet.contains(pathPatternStr)) { @@ -187,7 +186,7 @@ class HadoopTableReader( val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) val ifc = partDesc.getInputFileFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] @@ -325,7 +324,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] } else { - HiveShim.getConvertedOI( + ObjectInspectorConverters.getConvertedOI( 
rawDeser.getObjectInspector, tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 196a3d836cab2..16851fdd71a98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -90,14 +90,14 @@ private[hive] object IsolatedClientLoader { * `ClientInterface`, unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures - * that are not compatibile accross versions. + * that are not compatible across versions. * @param execJars A collection of jar files that must include hive and hadoop. * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. Must not know about hive classes. - * @param baseClassLoader The spark classloader that is used to load shared classes. - * + * @param rootClassLoader The system root classloader. + * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know + * about Hive classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -110,7 +110,7 @@ private[hive] class IsolatedClientLoader( val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { - // Check to make sure that the root classloader does not know about Hive. + // Check to make sure that the base classloader does not know about Hive. assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala index c600b158c5460..4d053ae42c2ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala @@ -30,7 +30,7 @@ private[client] object ReflectionException { /** * Provides implicit functions on any object for calling methods reflectively. 
*/ -protected trait ReflectionMagic { +private[client] trait ReflectionMagic { /** code for InstanceMagic println( (1 to 22).map { n => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 7db9200d47440..410d9881ac214 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -29,5 +29,5 @@ package object client { case object v13 extends HiveVersion("0.13.1", false) } // scalastyle:on - + } \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 62dc4167b78dd..11ee5503146b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -63,7 +63,7 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } - // Create a local copy of hiveconf,so that scan specific modifications should not impact + // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient private[this] val hiveExtraConf = new HiveConf(context.hiveconf) @@ -72,7 +72,7 @@ case class HiveTableScan( addColumnMetadataToConf(hiveExtraConf) @transient - private[this] val hadoopReader = + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c0b0b104e9142..eeb472602be3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import java.util -import scala.collection.JavaConversions._ - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import scala.collection.JavaConversions._ + private[hive] case class InsertIntoHiveTable( table: MetastoreRelation, @@ -106,7 +104,7 @@ case class InsertIntoHiveTable( } writerContainer - 
.getLocalFileWriter(row) + .getLocalFileWriter(row, table.schema) .write(serializer.serialize(outputData, standardOI)) } @@ -126,7 +124,7 @@ case class InsertIntoHiveTable( // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = HiveShim.getExternalTmpPath(hiveContext, tableLocation) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) @@ -194,10 +192,9 @@ case class InsertIntoHiveTable( if (partition.nonEmpty) { // loadPartition call orders directories created on the iteration order of the this map - val orderedPartitionSpec = new util.LinkedHashMap[String,String]() - table.hiveQlTable.getPartCols().foreach{ - entry=> - orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + val orderedPartitionSpec = new util.LinkedHashMap[String, String]() + table.hiveQlTable.getPartCols().foreach { entry => + orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index bfd26e0170c70..fd623370cc407 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -62,7 +62,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val reader = new BufferedReader(new InputStreamReader(inputStream)) - + val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { @@ -95,7 +95,7 @@ case class ScriptTransformation( val raw = outputSerde.deserialize(writable) val dataList = outputSoi.getStructFieldsDataAsList(raw) val fieldList = outputSoi.getAllStructFieldRefs() - + var i = 0 dataList.foreach( element => { if (element == null) { @@ -117,7 +117,7 @@ case class ScriptTransformation( if (!hasNext) { throw new NoSuchElementException } - + if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() @@ -192,7 +192,7 @@ case class HiveScriptIOSchema ( val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - + def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) @@ -206,22 +206,22 @@ case class HiveScriptIOSchema ( } def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name case _ => null } - + val columnTypes = attrs.map { case aref: AttributeReference => aref.dataType case e: NamedExpression => e.dataType - case _ => null + case _ => null } (columns, columnTypes) } - + def initSerDe(serdeClassName: String, columns: Seq[String], columnTypes: Seq[DataType], serdeProps: Seq[(String, 
String)]): AbstractSerDe = { @@ -240,7 +240,7 @@ case class HiveScriptIOSchema ( (kv._1.split("'")(1), kv._2.split("'")(1)) }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - + val properties = new Properties() properties.putAll(propsMap) serde.initialize(null, properties) @@ -261,7 +261,7 @@ case class HiveScriptIOSchema ( null } } - + def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { if (outputSerde != null) { outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bc6b3a2d58c38..6e6ac987b668a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.spark.sql.AnalysisException - import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions @@ -30,29 +27,31 @@ import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ -/* Implicit conversions */ -import scala.collection.JavaConversions._ private[hive] abstract class HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. 
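The change just below turns an unresolvable Hive function name into an `AnalysisException("undefined function ...")` rather than a generic `sys.error`. A sketch of that lookup-or-fail shape (`FunctionLookupSketch` and `requireFunction` are hypothetical; the sketch sits in a Spark-internal package only because `AnalysisException`'s constructor is `protected[sql]`):

```scala
package org.apache.spark.sql.hive

import org.apache.spark.sql.AnalysisException

object FunctionLookupSketch {
  // Resolve a function by lower-cased name or fail with an analysis error,
  // mirroring the sys.error -> AnalysisException change below.
  def requireFunction[T](name: String, lookup: String => Option[T]): T =
    lookup(name.toLowerCase).getOrElse(
      throw new AnalysisException(s"undefined function $name"))

  def main(args: Array[String]): Unit = {
    val registry = Map("upper" -> "GenericUDFUpper")
    println(requireFunction("UPPER", registry.get)) // resolves case-insensitively
    // requireFunction("no_such_udf", registry.get) // would throw AnalysisException
  }
}
```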
val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $name")) + throw new AnalysisException(s"undefined function $name")) val functionClassName = functionInfo.getFunctionClass.getName @@ -75,9 +74,11 @@ private[hive] abstract class HiveFunctionRegistry private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { - type EvaluatedType = Any + type UDFType = UDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient @@ -139,7 +140,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF - type EvaluatedType = Any + + override def deterministic: Boolean = isUDFDeterministic override def nullable: Boolean = true @@ -316,7 +318,7 @@ private[hive] case class HiveWindowFunction( // The object inspector of values returned from the Hive window function. @transient - protected lazy val returnInspector = { + protected lazy val returnInspector = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -336,8 +338,6 @@ private[hive] case class HiveWindowFunction( def nullable: Boolean = true - override type EvaluatedType = Any - override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") @@ -413,7 +413,7 @@ private[hive] case class HiveGenericUdaf( protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) @@ -446,7 +446,7 @@ private[hive] case class HiveUdaf( new GenericUDAFBridge(funcWrapper.createFunction()) @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) @@ -558,12 +558,12 @@ private[hive] case class HiveUdafFunction( } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - + private val inspectors = exprs.map(toInspector).toArray - - private val function = { + + private val function = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) + resolver.getEvaluator(parameterInfo) } private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @@ -578,7 +578,7 @@ private[hive] case class HiveUdafFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) - + def update(input: Row): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index cbc381cc81b59..ee440e304ec19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,8 +34,9 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} -import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.types._ /** * Internal helper class that saves an RDD using a Hive OutputFormat. @@ -69,7 +70,7 @@ private[hive] class SparkHiveWriterContainer( @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) @transient private lazy val outputFormat = - conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] def driverSideSetup() { setIDs(0, 0, 0) @@ -92,7 +93,7 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer def close() { // Seems the boolean value passed into close does not matter. @@ -195,11 +196,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + def convertToHiveRawString(col: String, value: Any): String = { + val raw = String.valueOf(value) + schema(col).dataType match { + case DateType => DateUtils.toString(raw.toInt) + case _: DecimalType => BigDecimal(raw).toString() + case _ => raw + } + } + val dynamicPartPath = dynamicPartColNames .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => - val string = if (rawVal == null) null else String.valueOf(rawVal) + val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) val colString = if (string == null || string.isEmpty) { defaultPartName diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 2e06cabfa80c9..7c7afc824d7a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -189,7 +189,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } - case class TestTable(name: String, commands: (()=>Unit)*) + case class TestTable(name: String, commands: (() => Unit)*) protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { @@ -253,8 +253,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { | 'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS - |INPUTFORMAT '${classOf[SequenceFileInputFormat[_,_]].getName}' - |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_,_]].getName}' + |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' + |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' """.stripMargin) runSqlHive( diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 945596db80326..39d315aaeab57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -57,7 +57,7 @@ class CachedTableSuite extends QueryTest { checkAnswer( sql("SELECT * FROM src s"), preCacheResults) - + uncacheTable("src") assertCached(sql("SELECT * FROM src"), 0) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2a7374cc172b7..df137e7b2b333 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -26,12 +26,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.LongWritable -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Literal, Row} import org.apache.spark.sql.types._ -class HiveInspectorSuite extends FunSuite with HiveInspectors { +class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { val udaf = new UDAFPercentile.PercentileLongEvaluator() udaf.init() @@ -78,10 +78,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(java.sql.Date.valueOf("2014-09-23")) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: - Literal(Array[Byte](1,2,3)) :: - Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal.create(Row(1,2.0d,3.0f), + Literal(Array[Byte](1, 2, 3)) :: + Literal.create(Seq[Int](1, 2, 3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1 -> 2, 2 -> 1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1, 2.0d, 3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -111,8 +111,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index fa8e11ffec2b4..e9bb32667936c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive +import org.apache.spark.SparkFunSuite import 
org.apache.spark.sql.hive.test.TestHive -import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends FunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite { test("struct field should accept underscore in sub-column name") { val metastr = "struct" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 941a2941649b8..f765395e148af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class HiveQlSuite extends FunSuite with BeforeAndAfterAll { +class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll() { if (SessionState.get() == null) { SessionState.start(new HiveConf()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index acf2f7da30188..aa5dbe2db6903 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -160,7 +160,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } @@ -240,7 +240,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { checkAnswer(sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) - + // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index e12a6c21ccac4..1c15997ea8e6d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -29,7 +29,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { import org.apache.spark.sql.hive.test.TestHive.implicits._ val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
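Several suites in these hunks (HiveInspectorSuite, HiveMetastoreCatalogSuite, HiveQlSuite, and more below) switch from ScalaTest's FunSuite to a shared SparkFunSuite base class. That class is not shown in this patch, so the sketch below is only an assumed shape for such a base: it wraps every test so the suite and test names are logged around the test output.

import org.apache.spark.Logging
import org.scalatest.{FunSuite, Outcome}

abstract class SparkFunSuiteSketch extends FunSuite with Logging {
  // Log a banner before and after every test so failures are easier to locate in long CI logs.
  protected override def withFixture(test: NoArgTest): Outcome = {
    val suiteName = this.getClass.getName
    try {
      logInfo(s"\n\n===== TEST OUTPUT FOR $suiteName: '${test.name}' =====\n")
      test()
    } finally {
      logInfo(s"\n\n===== FINISHED $suiteName: '${test.name}' =====\n")
    }
  }
}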
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 9623ef06aa9b0..af586712e3235 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -21,770 +21,816 @@ import java.io.File import scala.collection.mutable.ArrayBuffer +import org.scalatest.BeforeAndAfterAll + import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException -import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override val sqlContext = TestHive + + var jsonFilePath: String = _ - override def afterEach(): Unit = { - reset() - Utils.deleteRecursively(tempPath) + override def beforeAll(): Unit = { + jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } - val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var tempPath: File = Utils.createTempDir() - tempPath.delete() - - test ("persistent JSON table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(filePath).collect().toSeq) + test("persistent JSON table") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } - test ("persistent JSON table with a user specified schema") { - sql( - s""" - |CREATE TABLE jsonTable ( - |a string, - |b String, - |`c_!@(3)` int, - |`` Struct<`d!`:array, `=`:array>>) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - read.json(filePath).registerTempTable("expectedJsonTable") - - checkAnswer( - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) + } + } } - test ("persistent JSON table with a user 
specified schema with a subset of fields") { - // This works because JSON objects are self-describing and JSONRelation can get needed - // field values based on field names. - sql( - s""" - |CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val innerStruct = StructType( - StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))) :: Nil) - val expectedSchema = StructType( - StructField("", innerStruct, true) :: - StructField("b", StringType, true) :: Nil) - - assert(expectedSchema === table("jsonTable").schema) - - read.json(filePath).registerTempTable("expectedJsonTable") - - checkAnswer( - sql("SELECT b, ``.`=` FROM jsonTable"), - sql("SELECT b, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema with a subset of fields") { + withTable("jsonTable") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. + sql( + s"""CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val innerStruct = StructType(Seq( + StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))))) + + val expectedSchema = StructType(Seq( + StructField("", innerStruct, true), + StructField("b", StringType, true))) + + assert(expectedSchema === table("jsonTable").schema) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT b, ``.`=` FROM jsonTable"), + sql("SELECT b, ``.`=` FROM expectedJsonTable")) + } + } } test("resolve shortened provider names") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(filePath).collect().toSeq) + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } test("drop table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - read.json(filePath).collect().toSeq) - - sql("DROP TABLE jsonTable") - - intercept[Exception] { - sql("SELECT * FROM jsonTable").collect() - } + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath)) - assert( - (new File(filePath)).exists(), - "The table with specified path is considered as an external table, " + - "its data should not deleted after DROP TABLE.") + sql("DROP TABLE jsonTable") + + intercept[Exception] { + sql("SELECT * FROM jsonTable").collect() + } + + assert( + new File(jsonFilePath).exists(), + "The table with specified path is considered as an external table, " + + "its data should not deleted after DROP TABLE.") + } } test("check change without refresh") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: 
Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - // Schema is cached so the new column does not show. The updated values in existing columns - // will show. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1")) - - sql("REFRESH TABLE jsonTable") - - // Check that the refresh worked - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1", "c1")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + withTable("jsonTable") { + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a1", "b1", "c1") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + // Schema is cached so the new column does not show. The updated values in existing columns + // will show. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1")) + + sql("REFRESH TABLE jsonTable") + + // Check that the refresh worked + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1", "c1")) + } + } } test("drop, change, recreate") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql("DROP TABLE jsonTable") - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - // New table should reflect new schema. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b", "c")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql("DROP TABLE jsonTable") + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + // New table should reflect new schema. 
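The MetastoreDataSourcesSuite rewrite that runs through the following hunks replaces afterEach-style cleanup with SQLTestUtils helpers such as withTable, withTempTable, and withTempPath. SQLTestUtils itself is only imported by this patch, so the trait below is an assumed minimal version of that loan pattern: run the body, then clean up even if an assertion fails.

package org.apache.spark.sql.test

import java.io.File

import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils

trait SQLTestUtilsSketch {
  protected def sqlContext: SQLContext

  // Drop the named metastore tables once the body finishes, pass or fail.
  protected def withTable(tableNames: String*)(f: => Unit): Unit = {
    try f finally {
      tableNames.foreach(name => sqlContext.sql(s"DROP TABLE IF EXISTS $name"))
    }
  }

  // Same idea for temporary tables registered with registerTempTable.
  protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
    try f finally {
      tableNames.foreach(sqlContext.dropTempTable)
    }
  }

  // Hand the body a scratch path that is deleted afterwards; the path itself is not created,
  // so each test decides whether it should exist beforehand.
  protected def withTempPath(f: File => Unit): Unit = {
    val dir = Utils.createTempDir()
    dir.delete()
    try f(dir) finally Utils.deleteRecursively(dir)
  }
}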
+ checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b", "c")) + } + } } test("invalidate cache and reload") { - sql( - s""" - |CREATE TABLE jsonTable (`c_!@(3)` int) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable (`c_!@(3)` int) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - read.json(filePath).registerTempTable("expectedJsonTable") + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - // Discard the cached relation. - invalidateTable("jsonTable") + // Discard the cached relation. + invalidateTable("jsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - invalidateTable("jsonTable") - val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) + invalidateTable("jsonTable") + val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema === table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) + } + } } test("CTAS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { tempPath => + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS with IF NOT EXISTS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - // Create the table again should trigger a AnalysisException. - val message = intercept[AnalysisException] { - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - }.getMessage - assert(message.contains("Table ctasJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // The following statement should be fine if it has IF NOT EXISTS. 
- // It tries to create a table ctasJsonTable with a new schema. - // The actual table's schema and data should not be changed. - sql( - s""" - |CREATE TABLE IF NOT EXISTS ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT a FROM jsonTable - """.stripMargin) - - // Discard the cached relation. - invalidateTable("ctasJsonTable") - - // Schema should not be changed. - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - // Table data should not be changed. - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + // Create the table again should trigger a AnalysisException. + val message = intercept[AnalysisException] { + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + }.getMessage + + assert( + message.contains("Table ctasJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // The following statement should be fine if it has IF NOT EXISTS. + // It tries to create a table ctasJsonTable with a new schema. + // The actual table's schema and data should not be changed. + sql( + s"""CREATE TABLE IF NOT EXISTS ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT a FROM jsonTable + """.stripMargin) + + // Discard the cached relation. + invalidateTable("ctasJsonTable") + + // Schema should not be changed. + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + // Table data should not be changed. + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS a managed table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") - val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) - - // It is a managed table when we do not specify the location. 
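The managed-table CTAS test this hunk goes on to rewrite asserts on the table's storage directly, so it uses the Hadoop FileSystem API rather than java.io. A small sketch of that check; warehousePath stands in for catalog.hiveDefaultTableFilePath("ctasJsonTable") and is not a value defined by the patch.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object ManagedTablePathSketch {
  // A managed table created without a "path" option lives under the Hive warehouse; after
  // DROP TABLE its directory should be gone, which is exactly what the test asserts.
  def managedTableDataExists(warehousePath: String, hadoopConf: Configuration): Boolean = {
    val location = new Path(warehousePath)
    val fs = location.getFileSystem(hadoopConf)
    fs.exists(location)
  }
}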
- sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - - sql( - s""" - |CREATE TABLE loadedTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${expectedPath}' - |) - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("loadedTable").schema) - - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM loadedTable").collect() - ) - - sql("DROP TABLE ctasJsonTable") - assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + withTable("jsonTable", "ctasJsonTable", "loadedTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val filesystemPath = new Path(expectedPath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) + + // It is a managed table when we do not specify the location. + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") + + sql( + s"""CREATE TABLE loadedTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$expectedPath' + |) + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("loadedTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable")) + + sql("DROP TABLE ctasJsonTable") + assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + } } test("SPARK-5286 Fail to drop an invalid table when using the data source API") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path 'it is not a path at all!' - |) - """.stripMargin) - - sql("DROP TABLE jsonTable").collect().foreach(println) + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path 'it is not a path at all!' + |) + """.stripMargin) + + sql("DROP TABLE jsonTable").collect().foreach(println) + } } test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { - val originalDefaultSource = conf.defaultDataSourceName + withTable("savedJsonTable") { + // Save the df as a managed table (by not specifying the path). + (1 to 10) + .map(i => i -> s"str$i") + .toDF("a", "b") + .write + .format("json") + .saveAsTable("savedJsonTable") - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = read.json(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). 
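The removed lines in this area juggle SQLConf.DEFAULT_DATA_SOURCE_NAME by hand and restore it at the end of the test; the rewritten tests below wrap the same idea in withSQLConf. That helper is not shown in the patch either, so this is only an assumed minimal version of it:

import org.apache.spark.sql.SQLContext

trait WithSQLConfSketch {
  protected def sqlContext: SQLContext

  // Set the given SQL conf entries for the duration of `f`, then put the old values back.
  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
    val previous = pairs.map { case (key, _) => key -> Option(sqlContext.getConf(key, null)) }
    pairs.foreach { case (key, value) => sqlContext.setConf(key, value) }
    try f finally {
      previous.foreach {
        case (key, Some(old)) => sqlContext.setConf(key, old)
        case (_, None) => // a complete helper would also unset keys that had no previous value
      }
    }
  }
}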
- df.write.saveAsTable("savedJsonTable") + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + invalidateTable("savedJsonTable") - invalidateTable("savedJsonTable") + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) + } + } - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + test("save table") { + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("savedJsonTable") { + val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b") + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + // Save the df as a managed table (by not specifying the path). + df.write.saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") + } + + // We can overwrite it. + df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") + assert(df.schema === table("savedJsonTable").schema) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + } + + // Create an external table by specifying the path. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + df.write + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + } + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer(read.json(tempPath.toString), df) + } + } + } - // Drop table will also delete the data. 
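The rewritten "save table" test above is essentially a tour of DataFrameWriter's save modes against a managed table. A condensed usage sketch of the same API surface; the table name "events" and the input DataFrame are illustrative only.

import org.apache.spark.sql.{DataFrame, SaveMode}

object SaveModeTourSketch {
  def tour(df: DataFrame): Unit = {
    df.write.format("parquet").saveAsTable("events")                          // default ErrorIfExists: fails if "events" already exists
    df.write.format("parquet").mode(SaveMode.Append).saveAsTable("events")    // appends rows (the JSON source in the test rejects this)
    df.write.format("parquet").mode(SaveMode.Overwrite).saveAsTable("events") // replaces the existing data and schema
    df.write.format("parquet").mode(SaveMode.Ignore).saveAsTable("events")    // leaves the existing table untouched
    // Dropping a managed table also deletes its files; a table created with an explicit
    // "path" option is external, so DROP TABLE keeps the data in place.
  }
}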
- sql("DROP TABLE savedJsonTable") + test("create external table") { + withTempPath { tempPath => + withTable("savedJsonTable", "createdJsonTable") { + val df = read.json(sparkContext.parallelize((1 to 10).map { i => + s"""{ "a": $i, "b": "str$i" }""" + })) + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + df.write + .format("json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + } + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + createExternalTable("createdJsonTable", tempPath.toString) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer(sql("SELECT * FROM createdJsonTable"), df) + + assert( + intercept[AnalysisException] { + createExternalTable("createdJsonTable", jsonFilePath.toString) + }.getMessage.contains("Table createdJsonTable already exists."), + "We should complain that createdJsonTable already exists") + } + + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") + checkAnswer(read.json(tempPath.toString), df) + + // Try to specify the schema. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable")) + + sql("DROP TABLE createdJsonTable") + + assert( + intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage.contains("'path' must be specified for json data."), + "We should complain that path is not specified.") + } + } + } + } - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + test("scan a parquet table created through a CTAS statement") { + withSQLConf( + "spark.sql.hive.convertMetastoreParquet" -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + } + } + } + } } - test("save table") { - val originalDefaultSource = conf.defaultDataSourceName + test("Pre insert nullability check (ArrayType)") { + withTable("arrayInParquet") { + { + val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("arrayInParquet") + } - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = read.json(rdd) + { + val df = (Tuple1(Seq(2, 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + 
.mode(SaveMode.Append) + .insertInto("arrayInParquet") + } - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). - df.write.saveAsTable("savedJsonTable") + (Tuple1(Seq(4, 5)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) + (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") - // Right now, we cannot append to an existing JSON table. - intercept[RuntimeException] { - df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") - } + refreshTable("arrayInParquet") - // We can overwrite it. - df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // When the save mode is Ignore, we will do nothing when the table already exists. - df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") - assert(df.schema === table("savedJsonTable").schema) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") - intercept[InvalidInputException] { - read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + checkAnswer( + sql("SELECT a FROM arrayInParquet"), + Row(ArrayBuffer(1, null)) :: + Row(ArrayBuffer(2, 3)) :: + Row(ArrayBuffer(4, 5)) :: + Row(ArrayBuffer(6, null)) :: Nil) } - - // Create an external table by specifying the path. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.write - .format("org.apache.spark.sql.json") - .mode(SaveMode.Append) - .option("path", tempPath.toString) - .saveAsTable("savedJsonTable") - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Data should not be deleted after we drop the table. - sql("DROP TABLE savedJsonTable") - checkAnswer( - read.json(tempPath.toString), - df.collect()) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } - test("create external table") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = read.json(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.write.format("org.apache.spark.sql.json") - .mode(SaveMode.Append) - .option("path", tempPath.toString) - .saveAsTable("savedJsonTable") - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - createExternalTable("createdJsonTable", tempPath.toString) - assert(table("createdJsonTable").schema === df.schema) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - df.collect()) - - var message = intercept[AnalysisException] { - createExternalTable("createdJsonTable", filePath.toString) - }.getMessage - assert(message.contains("Table createdJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // Data should not be deleted. - sql("DROP TABLE createdJsonTable") - checkAnswer( - read.json(tempPath.toString), - df.collect()) - - // Try to specify the schema. 
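Both the rewritten "create external table" test above and the old code being removed here call createExternalTable with an explicit schema. A condensed usage sketch under the same assumptions; the path argument is a placeholder, and TestHive simply stands in for whichever HiveContext the suite uses.

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object CreateExternalTableSketch {
  def registerJsonTable(dataPath: String): Unit = {
    // With only a path the schema is inferred from the JSON data; passing a StructType pins
    // the schema, but the source still has to be told where the data lives via "path".
    val schema = StructType(StructField("b", StringType, nullable = true) :: Nil)
    TestHive.createExternalTable(
      "createdJsonTable",
      "org.apache.spark.sql.json",
      schema,
      Map("path" -> dataPath))
  }
}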
- conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - val schema = StructType(StructField("b", StringType, true) :: Nil) - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map("path" -> tempPath.toString)) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - sql("SELECT b FROM savedJsonTable").collect()) - - sql("DROP TABLE createdJsonTable") - - message = intercept[RuntimeException] { - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map.empty[String, String]) - }.getMessage - assert( - message.contains("'path' must be specified for json data."), - "We should complain that path is not specified.") - - sql("DROP TABLE savedJsonTable") - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) - } - - if (HiveShim.version == "0.13.1") { - test("scan a parquet table created through a CTAS statement") { - val originalConvertMetastore = getConf("spark.sql.hive.convertMetastoreParquet", "true") - val originalUseDataSource = getConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - setConf("spark.sql.hive.convertMetastoreParquet", "true") - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - read.json(rdd).registerTempTable("jt") - sql( - """ - |create table test_parquet_ctas STORED AS parquET - |AS select tmp.a from jt tmp where tmp.a < 5 - """.stripMargin) + test("Pre insert nullability check (MapType)") { + withTable("mapInParquet") { + { + val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("mapInParquet") + } - checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), - Row(3) :: Row(4) :: Nil - ) - - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") + { + val df = (Tuple1(Map(2 -> 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("mapInParquet") } - // Clenup and reset confs. - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS test_parquet_ctas") - setConf("spark.sql.hive.convertMetastoreParquet", originalConvertMetastore) - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource) - } - } + (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. 
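The two pre-insert nullability tests in this area hinge on containsNull and valueContainsNull being inferred from the Scala data, so the same column can arrive with different element nullability and the insert path has to reconcile them. A small sketch of that inference, reusing TestHive's implicits the way the tests do:

import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType}

object NullabilityInferenceSketch {
  import org.apache.spark.sql.hive.test.TestHive.implicits._

  def demo(): Unit = {
    val withNulls    = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a")
    val withoutNulls = (Tuple1(Seq(2, 3)) :: Nil).toDF("a")
    val nullValues   = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a")

    // A null element or value flips the corresponding nullability flag on the inferred type.
    assert(withNulls.schema("a").dataType == ArrayType(IntegerType, containsNull = true))
    assert(withoutNulls.schema("a").dataType == ArrayType(IntegerType, containsNull = false))
    assert(nullValues.schema("a").dataType ==
      MapType(IntegerType, IntegerType, valueContainsNull = true))
  }
}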
- test("Pre insert nullability check (ArrayType)") { - val df1 = - createDataFrame(Tuple1(Seq(Int.box(1), null.asInstanceOf[Integer])) :: Nil).toDF("a") - val expectedSchema1 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("arrayInParquet") - - val df2 = - createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a") - val expectedSchema2 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.write.mode(SaveMode.Append).insertInto("arrayInParquet") - createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) - .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a").write - .mode(SaveMode.Append).saveAsTable("arrayInParquet") - refreshTable("arrayInParquet") - - checkAnswer( - sql("SELECT a FROM arrayInParquet"), - Row(ArrayBuffer(1, null)) :: - Row(ArrayBuffer(2, 3)) :: - Row(ArrayBuffer(4, 5)) :: - Row(ArrayBuffer(6, null)) :: Nil) - - sql("DROP TABLE arrayInParquet") - } + (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") - test("Pre insert nullability check (MapType)") { - val df1 = - createDataFrame(Tuple1(Map(1 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") - val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val expectedSchema1 = - StructType( - StructField("a", mapType1, nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("mapInParquet") - - val df2 = - createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a") - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = false) - val expectedSchema2 = - StructType( - StructField("a", mapType2, nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.write.mode(SaveMode.Append).insertInto("mapInParquet") - createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) - .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a").write - .format("parquet").mode(SaveMode.Append).saveAsTable("mapInParquet") - refreshTable("mapInParquet") - - checkAnswer( - sql("SELECT a FROM mapInParquet"), - Row(Map(1 -> null)) :: - Row(Map(2 -> 3)) :: - Row(Map(4 -> 5)) :: - Row(Map(6 -> null)) :: Nil) - - sql("DROP TABLE mapInParquet") + refreshTable("mapInParquet") + + checkAnswer( + sql("SELECT a FROM mapInParquet"), + Row(Map(1 -> null)) :: + Row(Map(2 -> 3)) :: + Row(Map(4 -> 5)) :: + Row(Map(6 -> null)) :: Nil) + } } test("SPARK-6024 wide schema support") { - // We will need 80 splits for this schema if the threshold is 4000. - val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true))) - assert( - schema.json.size > conf.schemaStringLengthThreshold, - "To correctly test the fix of SPARK-6024, the value of " + - s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}") - // Manually create a metastore data source table. 
- catalog.createDataSourceTable( - tableName = "wide_schema", - userSpecifiedSchema = Some(schema), - partitionColumns = Array.empty[String], - provider = "json", - options = Map("path" -> "just a dummy path"), - isExternal = false) - - invalidateTable("wide_schema") - - val actualSchema = table("wide_schema").schema - assert(schema === actualSchema) + withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD -> "4000") { + withTable("wide_schema") { + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) + + // Manually create a metastore data source table. + catalog.createDataSourceTable( + tableName = "wide_schema", + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + provider = "json", + options = Map("path" -> "just a dummy path"), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } + } } test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { val tableName = "spark6655" - val schema = StructType(StructField("int", IntegerType, true) :: Nil) - - val hiveTable = HiveTable( - specifiedDatabase = Some("default"), - name = tableName, - schema = Seq.empty, - partitionColumns = Seq.empty, - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema" -> schema.json, - "EXTERNAL" -> "FALSE"), - tableType = ManagedTable, - serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(tableName))) - - catalog.client.createTable(hiveTable) - - invalidateTable(tableName) - val actualSchema = table(tableName).schema - assert(schema === actualSchema) - sql(s"drop table $tableName") + withTable(tableName) { + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + val hiveTable = HiveTable( + specifiedDatabase = Some("default"), + name = tableName, + schema = Seq.empty, + partitionColumns = Seq.empty, + properties = Map( + "spark.sql.sources.provider" -> "json", + "spark.sql.sources.schema" -> schema.json, + "EXTERNAL" -> "FALSE"), + tableType = ManagedTable, + serdeProperties = Map( + "path" -> catalog.hiveDefaultTableFilePath(tableName))) + + catalog.client.createTable(hiveTable) + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + } } test("Saving partition columns information") { - val df = - sparkContext.parallelize(1 to 10, 4).map { i => - Tuple4(i, i + 1, s"str$i", s"str${i + 1}") - }.toDF("a", "b", "c", "d") - + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") val tableName = s"partitionInfo_${System.currentTimeMillis()}" - df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) - invalidateTable(tableName) - val metastoreTable = catalog.client.getTable("default", tableName) - val expectedPartitionColumns = - StructType(df.schema("d") :: df.schema("b") :: Nil) - val actualPartitionColumns = - StructType( - metastoreTable.partitionColumns.map(c => - StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) - // Make sure partition columns are correctly stored in metastore. - assert( - expectedPartitionColumns.sameType(actualPartitionColumns), - s"Partitions columns stored in metastore $actualPartitionColumns is not the " + - s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.") - - // Check the content of the saved table. 
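The partition-column test being rewritten just below reads the metastore record back through catalog.client.getTable and compares types with sameType, which ignores nullability. A condensed sketch of that comparison; metastorePartitionColumns stands in for the (name, Hive type string) pairs returned by the client and is not an API introduced by the patch.

package org.apache.spark.sql.hive

import org.apache.spark.sql.types.{StructField, StructType}

object PartitionColumnsSketch {
  def partitionColumnsMatch(
      expected: StructType,
      metastorePartitionColumns: Seq[(String, String)]): Boolean = {
    // Hive stores partition column types as strings such as "int" or "string"; map them back
    // to Catalyst types before comparing, exactly like the test does.
    val actual = StructType(metastorePartitionColumns.map {
      case (name, hiveType) => StructField(name, HiveMetastoreTypes.toDataType(hiveType))
    })
    expected.sameType(actual)
  }
}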
- checkAnswer( - table(tableName).selectExpr("c", "b", "d", "a"), - df.selectExpr("c", "b", "d", "a").collect()) - - sql(s"drop table $tableName") + + withTable(tableName) { + df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) + val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val actualPartitionColumns = + StructType( + metastoreTable.partitionColumns.map(c => + StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) + // Make sure partition columns are correctly stored in metastore. + assert( + expectedPartitionColumns.sameType(actualPartitionColumns), + s"Partitions columns stored in metastore $actualPartitionColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } } test("insert into a table") { - def createDF(from: Int, to: Int): DataFrame = - createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") + def createDF(from: Int, to: Int): DataFrame = { + (from to to).map(i => i -> s"str$i").toDF("c1", "c2") + } - createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 9).map(i => Row(i, s"str$i"))) + withTable("insertParquet") { + createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 9).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") - } + intercept[AnalysisException] { + createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + } - createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 19).map(i => Row(i, s"str$i"))) + createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), - (6 to 24).map(i => Row(i, s"str$i"))) + createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + (6 to 24).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(30, 39).write.saveAsTable("insertParquet") - } + intercept[AnalysisException] { + createDF(30, 39).write.saveAsTable("insertParquet") + } + + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + (6 to 34).map(i => Row(i, s"str$i"))) - createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), - (6 to 34).map(i => Row(i, s"str$i"))) - - createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 
> 5 AND p.c1 < 45"), - (6 to 44).map(i => Row(i, s"str$i"))) - - createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), - (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (50 to 59).map(i => Row(i, s"str$i"))) - - createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (70 to 79).map(i => Row(i, s"str$i"))) + createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + (6 to 44).map(i => Row(i, s"str$i"))) + + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + (52 to 54).map(i => Row(i, s"str$i"))) + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (50 to 59).map(i => Row(i, s"str$i"))) + + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (70 to 79).map(i => Row(i, s"str$i"))) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 4990092df6a99..017bc2adc103b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -20,16 +20,17 @@ package org.apache.spark.sql.hive import com.google.common.io.Files import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest { - import org.apache.spark.sql.hive.test.TestHive.implicits._ + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ + import ctx.sql test("SPARK-5068: query data when path doesn't exist"){ - val testData = TestHive.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") @@ -48,8 +49,8 @@ class QueryPartitionSuite extends QueryTest { // test for the exist path checkAnswer(sql("select key,value from table_with_partition"), - testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect - ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) // delete the path of one partition tmpDir.listFiles @@ -58,8 +59,7 @@ class QueryPartitionSuite extends QueryTest { // test for after delete the path checkAnswer(sql("select key,value from table_with_partition"), - testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect - ++ testData.toSchemaRDD.collect) + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) sql("DROP TABLE table_with_partition") sql("DROP TABLE createAndInsertTest") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 
8afe5459d4f1b..93dcb10f7a296 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.hive -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.hive.test.TestHive -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5840] HiveContext should be serializable") { - val hiveContext = TestHive + val hiveContext = org.apache.spark.sql.hive.test.TestHive hiveContext.hiveconf val serializer = new JavaSerializer(new SparkConf()).newInstance() val bytes = serializer.serialize(hiveContext) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 00a69de9e4262..78c94e6490e36 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -23,13 +23,18 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - TestHive.reset() - TestHive.cacheTables = false + + private lazy val ctx: HiveContext = { + val ctx = org.apache.spark.sql.hive.test.TestHive + ctx.reset() + ctx.cacheTables = false + ctx + } + + import ctx.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -72,17 +77,13 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - // TODO: How does it works? needs to add it back for other hive version. 
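SerializationSuite, StatisticsSuite, and the UDFSuite further down all drop the top-level import of TestHive._ in favour of a lazily initialized context value, so the fairly heavy TestHive setup only happens once a test in the suite actually runs. The shape of that pattern, with an illustrative suite name and query:

import org.apache.spark.sql.QueryTest

class LazyTestHiveSuiteSketch extends QueryTest {
  private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
  import ctx.implicits._
  import ctx.sql

  test("the context is only created when a test runs") {
    // "src" is one of TestHive's lazily registered test tables.
    assert(sql("SELECT key FROM src LIMIT 1").count() == 1)
  }
}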
- if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes) - } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) @@ -110,7 +111,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -121,9 +122,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { - analyze("tempTable") + ctx.analyze("tempTable") } - catalog.unregisterTable(Seq("tempTable")) + ctx.catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { @@ -151,8 +152,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold - && sizes(1) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -163,8 +164,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedAnswer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") df = sql(query) @@ -207,8 +208,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= conf.autoBroadcastJoinThreshold - && sizes(0) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -221,8 +222,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, answer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") df = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 85b6bc93d7122..4056dee777574 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.hive -/* Implicits */ - import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + + 
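// A minimal sketch (not part of this patch) of the behaviour the UDF test below asserts:
// functions are registered through the context's udf registry and resolved case-insensitively
// from SQL. Assumes a HiveContext named `ctx` with TestHive's `src` table available.
ctx.udf.register("strlenScala", (s: String, i: Int) => s.length + i)
// Any casing works at call time because function resolution is case-insensitive.
assert(ctx.sql("SELECT STRLENSCALA('test', 1) FROM src LIMIT 1").head().getInt(0) == 5)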
private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + test("UDF case insensitive") { - udf.register("random0", () => { Math.random()}) - udf.register("RANDOM1", () => { Math.random()}) - udf.register("strlenScala", (_: String).length + (_:Int)) - assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + ctx.udf.register("random0", () => { Math.random() }) + ctx.udf.register("RANDOM1", () => { Math.random() }) + ctx.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 321dc8d7322b8..7eb4842726665 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.hive.client -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils -import org.scalatest.FunSuite /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version - * of hive from maven central. These tests are simple in that they are mostly just testing to make - * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity + * A simple set of tests that call the methods of a hive ClientInterface, loading different versions + * of hive from maven central. These tests are simple in that they are mostly just testing to make + * sure that reflective calls are not throwing NoSuchMethod error, but the actual functionality * is not fully tested. 
*/ -class VersionsSuite extends FunSuite with Logging { +class VersionsSuite extends SparkFunSuite with Logging { private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index 23ece7e7cf6e9..b0d3dd44daedc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { +class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 9c056e493bfde..c9dd4c0935a72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} +import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -40,7 +40,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { /** * When set, any cache files that result in test failures will be deleted. Used when the test @@ -273,7 +273,7 @@ abstract class HiveComparisonTest } val hiveCacheFiles = queryList.zipWithIndex.map { - case (queryString, i) => + case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" new File(answerCache, cachedAnswerName) } @@ -304,7 +304,7 @@ abstract class HiveComparisonTest // other DDL has not been executed yet. hiveQueries.foreach(_.logical) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { - case ((queryString, i), hiveQuery, cachedAnswerFile)=> + case ((queryString, i), hiveQuery, cachedAnswerFile) => try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. 
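// The recurring pattern in the suites above, shown in isolation: test classes extend the shared
// SparkFunSuite base instead of ScalaTest's FunSuite directly, and reach TestHive through a lazy
// val so the metastore and SparkContext are only initialized when a test actually runs. A
// hypothetical in-tree suite following that pattern (class and test names are illustrative):
package org.apache.spark.sql.hive

import org.apache.spark.SparkFunSuite

class LazyContextPatternSuite extends SparkFunSuite {
  private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
  import ctx.sql

  test("context is only initialized when first used") {
    assert(sql("SELECT key FROM src LIMIT 1").collect().length === 1)
  }
}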
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4af31d482ce42..6d8d99ebc8164 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -57,7 +57,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF sql( """ - |CREATE TEMPORARY FUNCTION udtf_count2 + |CREATE TEMPORARY FUNCTION udtf_count2 |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' """.stripMargin) } @@ -874,15 +874,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |WITH serdeproperties('s1'='9') """.stripMargin) } - // Now only verify 0.12.0, and ignore other versions due to binary compatibility - // current TestSerDe.jar is from 0.12.0 - if (HiveShim.version == "0.12.0") { - sql(s"ADD JAR $testJar") - sql( - """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' - |WITH serdeproperties('s1'='9') - """.stripMargin) - } sql("DROP TABLE alter1") } @@ -890,15 +881,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // this is a test case from mapjoin_addjar.q val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath - if (HiveShim.version == "0.13.1") { - sql(s"ADD JAR $testJar") - sql( - """CREATE TABLE t1(a string, b string) - |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) - sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") - sql("select * from src join t1 on src.key = t1.a") - sql("DROP TABLE t1") - } + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") } test("ADD FILE command") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 3dfa6e72e1242..b08db6de2d2f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -77,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") @@ -88,14 +88,14 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { 
- sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index ab53c6309e089..2209fc2f30a3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -61,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() TestHive.sql("drop table tb") } - + test("Spark-4077: timestamp query for null value") { TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") TestHive.sql( @@ -71,12 +71,12 @@ class HiveTableScanSuite extends HiveComparisonTest { FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' """.stripMargin) - val location = + val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() - === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 7f49eac490572..ce5985888f540 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -101,7 +101,7 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") TestHive.reset() } - + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index b707f5e68489b..40a35674e4cb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -327,38 +327,36 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.ql.io.RCFileInputFormat", "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", 
"tbl_p2=p22","MANAGED_TABLE" + "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - if (HiveShim.version =="0.13.1") { - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() - - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") - checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) - } + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) } } @@ -780,6 +778,42 @@ class SQLQuerySuite extends QueryTest { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: multiple window expressions in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.registerTempTable("nums") + + val expected = + Row(1, 1, 1, 55, 1, 57) :: + Row(0, 2, 3, 55, 2, 60) :: + Row(1, 3, 6, 55, 4, 65) :: + Row(0, 4, 10, 55, 6, 71) :: + Row(1, 5, 15, 55, 9, 79) :: + Row(0, 6, 21, 55, 12, 88) :: + Row(1, 7, 28, 55, 16, 99) :: + Row(0, 8, 36, 55, 20, 111) :: + Row(1, 9, 45, 55, 25, 125) :: + Row(0, 10, 55, 55, 30, 140) :: Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) OVER w1 AS running_sum, + | sum(x) OVER w2 AS total_sum, + | sum(x) OVER w3 AS running_sum_per_y, + | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), + | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), + | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """.stripMargin) + + checkAnswer(actual, expected) + + dropTempTable("nums") + } + test("test case key when") { (1 
to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") checkAnswer( @@ -837,4 +871,37 @@ class SQLQuerySuite extends QueryTest { java.lang.Math.exp(1.0).toString, java.lang.Math.floor(1.9).toString)) } + + test("dynamic partition value test") { + try { + sql("set hive.exec.dynamic.partition.mode=nonstrict") + // date + sql("drop table if exists dynparttest1") + sql("create table dynparttest1 (value int) partitioned by (pdate date)") + sql( + """ + |insert into table dynparttest1 partition(pdate) + | select count(*), cast('2015-05-21' as date) as pdate from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest1"), + Seq(Row(500, java.sql.Date.valueOf("2015-05-21")))) + + // decimal + sql("drop table if exists dynparttest2") + sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") + sql( + """ + |insert into table dynparttest2 partition(pdec) + | select count(*), cast('100.12' as decimal(5, 1)) as pdec from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest2"), + Seq(Row(500, new java.math.BigDecimal("100.1")))) + } finally { + sql("drop table if exists dynparttest1") + sql("drop table if exists dynparttest2") + sql("set hive.exec.dynamic.partition.mode=strict") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 88c99e35260d9..0e63d84e9824a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} +import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -38,7 +39,7 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { +class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal def withTempDir(f: File => Unit): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index cdd6e705f4a2c..b384fb39f3d66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -21,8 +21,9 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import 
org.apache.spark.sql.hive.test.TestHive @@ -50,10 +51,7 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest { - override val sqlContext = TestHive - - import TestHive.read +class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) @@ -68,7 +66,7 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll w withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + sqlContext.read.format("orc").load(file), data.toDF().collect()) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 750f0b04aaa87..5daf691aa8c53 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,13 +22,11 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql._ private[sql] trait OrcTest extends SQLTestUtils { - protected def hiveContext = sqlContext.asInstanceOf[HiveContext] + lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive import sqlContext.sparkContext import sqlContext.implicits._ @@ -53,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) + withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path))) } /** @@ -65,7 +63,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => - hiveContext.registerDataFrameAsTable(df, tableName) + sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 7851f38fd4056..e62ac909cbd0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -38,7 +38,7 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) -case class StructContainer(intStructField :Int, stringStructField: String) +case class StructContainer(intStructField: Int, stringStructField: String) case class ParquetDataWithComplexTypes( intField: Int, @@ -735,7 +735,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val filePath = new File(tempDir, "testParquet").getCanonicalPath val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath - val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") intercept[Throwable](df2.write.parquet(filePath)) diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index de907846b9180..0f959b3d0b86d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputForma import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext} @@ -108,7 +109,10 @@ class SimpleTextRelation( sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => Row(record.split(",").zip(fields).map { case (value, dataType) => - Cast(Literal(value), dataType).eval() + // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) + val catalystValue = Cast(Literal(value), dataType).eval() + // Here we're converting Catalyst values to Scala values to test `needsConversion` + CatalystTypeConverters.convertToScala(catalystValue, dataType) }: _*) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 70328e1ef810d..76469d7a3d6a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.sources +import java.io.File + +import com.google.common.io.Files import org.apache.hadoop.fs.Path -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive @@ -28,9 +30,9 @@ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive + override lazy val sqlContext: SQLContext = TestHive - import sqlContext._ + import sqlContext.sql import sqlContext.implicits._ val dataSourceName = classOf[SimpleTextSource].getCanonicalName @@ -41,19 +43,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { StructField("a", IntegerType, nullable = false), StructField("b", StringType, nullable = false))) - val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") + lazy val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") - val partitionedTestDF1 = (for { + lazy val partitionedTestDF1 = (for { i <- 1 to 3 p2 <- Seq("foo", "bar") } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") - val partitionedTestDF2 = (for { + lazy val partitionedTestDF2 = (for { i <- 1 to 3 p2 <- Seq("foo", "bar") } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") - val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) + lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) def checkQueries(df: DataFrame): Unit = { // Selects everything @@ -76,6 +78,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { df.filter('a > 1 
&& 'p1 < 2).select('b, 'p1), for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) + // Project many copies of columns with different types (reproduction for SPARK-7858) + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) + yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1)) + // Self-join df.registerTempTable("t") withTempTable("t") { @@ -95,7 +103,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("path", file.getCanonicalPath) .option("dataSchema", dataSchema.json) .load(), @@ -109,7 +117,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), testDF.unionAll(testDF).orderBy("a").collect()) @@ -143,7 +151,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkQueries( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath)) } @@ -164,7 +172,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -186,7 +194,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.unionAll(partitionedTestDF).collect()) @@ -208,7 +216,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -244,7 +252,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.collect()) + checkAnswer(sqlContext.table("t"), testDF.collect()) } } @@ -253,7 +261,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -272,7 +280,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { withTempTable("t") { testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") - assert(table("t").collect().isEmpty) + assert(sqlContext.table("t").collect().isEmpty) } } @@ -283,7 +291,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkQueries(table("t")) + checkQueries(sqlContext.table("t")) } } @@ -303,7 +311,7 @@ abstract class HadoopFsRelationTest extends QueryTest 
with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -323,7 +331,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } @@ -343,7 +351,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -392,7 +400,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .partitionBy("p1", "p2") .saveAsTable("t") - assert(table("t").collect().isEmpty) + assert(sqlContext.table("t").collect().isEmpty) } } @@ -404,7 +412,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .partitionBy("p1", "p2") .save(file.getCanonicalPath) - val df = read + val df = sqlContext.read .format(dataSourceName) .option("dataSchema", dataSchema.json) .load(s"${file.getCanonicalPath}/p1=*/p2=???") @@ -444,10 +452,24 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTempTable("t") { - checkAnswer(table("t"), input.collect()) + checkAnswer(sqlContext.table("t"), input.collect()) } } } + + test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + + df.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("c", "a") + .saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) + } + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -479,7 +501,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } -class CommitFailureTestRelationSuite extends FunSuite with SQLTestUtils { +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { import TestHive.implicits._ override val sqlContext = TestHive @@ -529,17 +551,62 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } - test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { - val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() - df.write - .format("parquet") - .mode(SaveMode.Overwrite) - .partitionBy("c", "a") - .saveAsTable("t") + df.write + .format("parquet") + .save(dir.getCanonicalPath) - withTable("t") { - checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. 
+ df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[RuntimeException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. + range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } } } } diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala deleted file mode 100644 index 33e96eaabfbf6..0000000000000 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import java.net.URI -import java.util.{ArrayList => JArrayList, Properties} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions - -import org.apache.hadoop.{io => hadoopIo} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.stats.StatsSetupConst -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat - -import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} - -private[hive] case class HiveFunctionWrapper(functionClassName: String) - extends java.io.Serializable { - - // for Serialization - def this() = this(null) - - import org.apache.spark.util.Utils._ - def createFunction[UDFType <: AnyRef](): UDFType = { - getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - } -} - -/** - * A compatibility layer for interacting with Hive version 0.12.0. 
- */ -private[hive] object HiveShim { - val version = "0.12.0" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) - } - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, - getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, - getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, - getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, - getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, - getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, - getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, - getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, - getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, - getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, - getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, - getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DECIMAL, - getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.VOID, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - - def getBooleanWritable(value: Any): 
hadoopIo.BooleanWritable = - if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[String] - - def processResults(results: JArrayList[String]) = results - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd(0), conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - new HiveDecimal(bd) - } - - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - ColumnProjectionUtils.appendReadColumnIDs(conf, ids) - ColumnProjectionUtils.appendReadColumnNames(conf, names) - } - - def getExternalTmpPath(context: Context, uri: URI) = { - context.getExternalTmpFileURI(uri) - } - - def getDataLocationPath(p: Partition) = p.getPartitionPath - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) - - def compatibilityBlackList = Seq( - "decimal_.*", - "udf7", - "drop_partitions_filter2", - "show_.*", - "serde_regex", - "udf_to_date", - "udaf_collect_set", - "udf_concat" - ) - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) - } - - def decimalMetastoreString(decimalType: DecimalType): String = "decimal" - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = - TypeInfoFactory.decimalTypeInfo - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - DecimalType.Unlimited - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - } - } - - def getConvertedOI( - inputOI: ObjectInspector, - outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) - } - - def prepareWritable(w: Writable): Writable = { - w 
- } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {} -} - -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends FileSinkDesc(dir, tableInfo, compressed) { -} diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala deleted file mode 100644 index dbc5e029e2047..0000000000000 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ /dev/null @@ -1,457 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.rmi.server.UID -import java.util.{Properties, ArrayList => JArrayList} -import java.io.{OutputStream, InputStream} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions -import scala.reflect.ClassTag - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.{io => hadoopIo} - -import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} -import org.apache.spark.util.Utils._ - -/** - * This class provides the UDF creation and also the UDF instance serialization and - * de-serialization cross process boundary. 
- * - * Detail discussion can be found at https://github.com/apache/spark/pull/3640 - * - * @param functionClassName UDF class name - */ -private[hive] case class HiveFunctionWrapper(var functionClassName: String) - extends java.io.Externalizable { - - // for Serialization - def this() = this(null) - - @transient - def deserializeObjectByKryo[T: ClassTag]( - kryo: Kryo, - in: InputStream, - clazz: Class[_]): T = { - val inp = new Input(in) - val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] - inp.close() - t - } - - @transient - def serializeObjectByKryo( - kryo: Kryo, - plan: Object, - out: OutputStream ) { - val output: Output = new Output(out) - kryo.writeObject(output, plan) - output.close() - } - - def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) - .asInstanceOf[UDFType] - } - - def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) - } - - private var instance: AnyRef = null - - def writeExternal(out: java.io.ObjectOutput) { - // output the function name - out.writeUTF(functionClassName) - - // Write a flag if instance is null or not - out.writeBoolean(instance != null) - if (instance != null) { - // Some of the UDF are serializable, but some others are not - // Hive Utilities can handle both cases - val baos = new java.io.ByteArrayOutputStream() - serializePlan(instance, baos) - val functionInBytes = baos.toByteArray - - // output the function bytes - out.writeInt(functionInBytes.length) - out.write(functionInBytes, 0, functionInBytes.length) - } - } - - def readExternal(in: java.io.ObjectInput) { - // read the function name - functionClassName = in.readUTF() - - if (in.readBoolean()) { - // if the instance is not null - // read the function in bytes - val functionInBytesLength = in.readInt() - val functionInBytes = new Array[Byte](functionInBytesLength) - in.read(functionInBytes, 0, functionInBytesLength) - - // deserialize the function object via Hive Utilities - instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), - getContextOrSparkClassLoader.loadClass(functionClassName)) - } - } - - def createFunction[UDFType <: AnyRef](): UDFType = { - if (instance != null) { - instance.asInstanceOf[UDFType] - } else { - val func = getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - if (!func.isInstanceOf[UDF]) { - // We cache the function if it's no the Simple UDF, - // as we always have to create new instance for Simple UDF - instance = func - } - func - } - } -} - -/** - * A compatibility layer for interacting with Hive version 0.13.1. 
- */ -private[hive] object HiveShim { - val version = "0.13.1" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(inputFormatClass, outputFormatClass, properties) - } - - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.voidTypeInfo, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) 
- } - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - } - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - } - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - } - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - // TODO precise, scale? - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[Object] - - def processResults(results: JArrayList[Object]) = { - results.map { r => - r match { - case s: String => s - case a: Array[Object] => a(0).asInstanceOf[String] - } - } - } - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { - context.runSqlHive("CREATE DATABASE default") - context.runSqlHive("USE default") - } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd, conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - HiveDecimal.create(bd) - } - - /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug - */ - private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { - val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") - val result: StringBuilder = new StringBuilder(old) - var first: Boolean = old.isEmpty - - for (col <- cols) { - if (first) { - first = false - } else { - result.append(',') - } - result.append(col) - } - conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) - } - - /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty - */ - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { - ColumnProjectionUtils.appendReadColumns(conf, ids) - } - if (names != null && names.size > 0) { - appendReadColumnNames(conf, names) - } - } - - def getExternalTmpPath(context: Context, path: Path) = { - context.getExternalTmpPath(path.toUri) - } - - def getDataLocationPath(p: Partition) = p.getDataLocation - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsOf(tbl) - - def compatibilityBlackList = Seq() - - def setLocation(tbl: Table, crtTbl: 
CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation())) - } - - /* - * Bug introdiced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ - implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) - f.setCompressCodec(w.compressCodec) - f.setCompressType(w.compressType) - f.setTableInfo(w.tableInfo) - f.setDestTableId(w.destTableId) - f - } - - // Precision and scale to pass for unlimited decimals; these are the same as the precision and - // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) - private val UNLIMITED_DECIMAL_PRECISION = 38 - private val UNLIMITED_DECIMAL_SCALE = 18 - - def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { - case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" - case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" - } - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { - case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) - } - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] - DecimalType(info.precision(), info.scale()) - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, - hdoi.precision(), hdoi.scale()) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) - } - } - - def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) - } - - /* - * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that - * is needed to initialize before serialization. - */ - def prepareWritable(w: Writable): Writable = { - w match { - case w: AvroGenericRecordWritable => - w.setRecordReaderID(new UID()) - case _ => - } - w - } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = { - if (crtTbl != null && crtTbl.getNullFormat() != null) { - tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat()) - } - } -} - -/* - * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. - * Fix it through wrapper. 
- */ -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends Serializable with Logging { - var compressCodec: String = _ - var compressType: String = _ - var destTableId: Int = _ - - def setCompressed(compressed: Boolean) { - this.compressed = compressed - } - - def getDirName = dir - - def setDestTableId(destTableId: Int) { - this.destTableId = destTableId - } - - def setTableInfo(tableInfo: TableDesc) { - this.tableInfo = tableInfo - } - - def setCompressCodec(intermediateCompressorCodec: String) { - compressCodec = intermediateCompressorCodec - } - - def setCompressType(intermediateCompressType: String) { - compressType = intermediateCompressType - } -} diff --git a/streaming/pom.xml b/streaming/pom.xml index 5ab7f4472c38b..697895e72fe5b 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 5e58ed714829e..9cd9684d36404 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map import scala.collection.mutable.Queue import scala.reflect.ClassTag +import scala.util.control.NonFatal import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration @@ -270,6 +271,8 @@ class StreamingContext private[streaming] ( * Create an input stream with any arbitrary user implemented receiver. * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver + * + * @deprecated As of 1.0.0, replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -461,7 +464,7 @@ class StreamingContext private[streaming] ( val conf = sc_.hadoopConfiguration conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( - directory, FileInputDStream.defaultFilter : Path => Boolean, newFilesOnly=true, conf) + directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) val data = br.map { case (k, v) => val bytes = v.getBytes require(bytes.length == recordLength, "Byte array does not have correct length. 
" + @@ -576,18 +579,26 @@ class StreamingContext private[streaming] ( def start(): Unit = synchronized { state match { case INITIALIZED => - validate() startSite.set(DStream.getCreationSite()) sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() - scheduler.start() - uiTab.foreach(_.attach()) - state = StreamingContextState.ACTIVE + try { + validate() + scheduler.start() + state = StreamingContextState.ACTIVE + } catch { + case NonFatal(e) => + logError("Error starting the context, marking it as stopped", e) + scheduler.stop(false) + state = StreamingContextState.STOPPED + throw e + } StreamingContext.setActiveContext(this) } shutdownHookRef = Utils.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + uiTab.foreach(_.attach()) logInfo("StreamingContext started") case ACTIVE => logWarning("StreamingContext has already been started") @@ -608,6 +619,8 @@ class StreamingContext private[streaming] ( * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { @@ -732,6 +745,10 @@ object StreamingContext extends Logging { } } + /** + * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. + * This is kept here only for backward compatibility. + */ @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 93baad19e3ee1..959ac9c177f81 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -227,7 +227,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param numPartitions Number of partitions of each RDD in the new DStream. 
*/ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - :JavaPairDStream[K, JIterable[V]] = { + : JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) .mapValues(asJavaIterable _) } @@ -247,7 +247,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ):JavaPairDStream[K, JIterable[V]] = { + ): JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) .mapValues(asJavaIterable _) } @@ -262,7 +262,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def reduceByKeyAndWindow(reduceFunc: JFunction2[V, V, V], windowDuration: Duration) - :JavaPairDStream[K, V] = { + : JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } @@ -281,7 +281,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration - ):JavaPairDStream[K, V] = { + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index b639b94d5ca47..989e3a729ebc2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -148,6 +148,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) + /** + * @deprecated As of 0.9.0, replaced by `sparkContext` + */ @deprecated("use sparkContext", "0.9.0") val sc: JavaSparkContext = sparkContext @@ -619,6 +622,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { @@ -677,6 +681,7 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. */ @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( @@ -699,6 +704,7 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -724,6 +730,7 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. 
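A few hunks above, StreamingContext.start() is rewrapped so that validate() and scheduler.start() run inside a try/catch: a failed start rolls the context back to STOPPED instead of leaving it half-initialized, and the UI tab is attached only after startup has succeeded. A minimal sketch of that guarded-start pattern, using hypothetical stand-ins rather than the real StreamingContext internals:

    import scala.util.control.NonFatal

    // Hypothetical sketch of the guarded-start pattern: either the service ends up
    // Active with its sub-components running, or it is rolled back to Stopped and
    // the original error is rethrown.
    object ServiceState extends Enumeration { val Initialized, Active, Stopped = Value }

    class GuardedService {
      @volatile private var state = ServiceState.Initialized

      private def validate(): Unit = require(state == ServiceState.Initialized, "already started")
      private def startScheduler(): Unit = ()   // stand-in for scheduler.start()
      private def stopScheduler(): Unit = ()    // stand-in for scheduler.stop(false)
      private def attachUi(): Unit = ()         // stand-in for uiTab.foreach(_.attach())

      def start(): Unit = synchronized {
        try {
          validate()
          startScheduler()
          state = ServiceState.Active
        } catch {
          case NonFatal(e) =>
            stopScheduler()                     // clean up anything partially started
            state = ServiceState.Stopped
            throw e                             // surface the original failure to the caller
        }
        attachUi()                              // only attach the UI once the core is up
      }
    }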
*/ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index c858647c6406d..192aa6a139bcb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -603,6 +603,8 @@ abstract class DStream[T: ClassTag] ( /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { @@ -612,6 +614,8 @@ abstract class DStream[T: ClassTag] ( /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { @@ -659,7 +663,7 @@ abstract class DStream[T: ClassTag] ( // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = context.sparkContext.clean(transformFunc, false) - val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { + val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 1) cleanedF(rdds.head.asInstanceOf[RDD[T]], time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index eca69f00188e4..6c1fab56740ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -69,7 +69,7 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * processing semantics are undefined. 
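Several of the hunks above are documentation-only: each existing @deprecated(...) annotation gains a matching scaladoc @deprecated tag, so the suggested replacement shows up in the generated API docs as well as in compiler deprecation warnings. The convention, sketched on a hypothetical method (not an actual API from this diff):

    class DeprecationDocExample {
      def foreachRDD(f: Int => Unit): Unit = ()          // stand-in for the replacement API

      /**
       * Apply a function to each element.
       *
       * @deprecated As of 0.9.0, replaced by `foreachRDD`.
       */
      @deprecated("use foreachRDD", "0.9.0")
      def foreach(f: Int => Unit): Unit = foreachRDD(f)  // old name forwards to the new one
    }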
*/ private[streaming] -class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( +class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( @transient ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, @@ -251,7 +251,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file =>{ + val fileRDDs = files.map { file => val rdd = serializableConfOpt.map(_.value) match { case Some(config) => context.sparkContext.newAPIHadoopFile( file, @@ -267,7 +267,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( "Refer to the streaming programming guide for more details.") } rdd - }) + } new UnionRDD(context.sparkContext, fileRDDs) } @@ -294,7 +294,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () + generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index fda22eb6ec42e..358e4c66df7ba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.StreamingContext.rddToFileName /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. 
*/ -class PairDStreamFunctions[K, V](self: DStream[(K,V)]) +class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) extends Serializable { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index e4ff05e12f201..e76e7eb0dea19 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -70,7 +70,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker - val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum) + val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) if (blockInfos.nonEmpty) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index df9f7f140eddc..6a583bf2a3626 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -38,7 +38,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( _windowDuration: Duration, _slideDuration: Duration, partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { + ) extends DStream[(K, V)](parent.ssc) { require(_windowDuration.isMultipleOf(parent.slideDuration), "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " + @@ -58,7 +58,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(reducedStream) @@ -68,7 +68,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( override def parentRememberDuration: Duration = rememberDuration + windowDuration - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + override def persist(storageLevel: StorageLevel): DStream[(K, V)] = { super.persist(storageLevel) reducedStream.persist(storageLevel) this @@ -118,7 +118,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Get the RDD of the reduced value of the previous window val previousWindowRDD = - getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K, V)]())) // Make the list of RDDs that needs to cogrouped together for reducing their reduced values val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 7757ccac09a58..e0ffd5d86b435 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -25,19 +25,19 @@ import scala.reflect.ClassTag private[streaming] class 
ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( - parent: DStream[(K,V)], + parent: DStream[(K, V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true - ) extends DStream[(K,C)] (parent.ssc) { + ) extends DStream[(K, C)] (parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - override def compute(validTime: Time): Option[RDD[(K,C)]] = { + override def compute(validTime: Time): Option[RDD[(K, C)]] = { parent.getOrCompute(validTime) match { case Some(rdd) => Some(rdd.combineByKey[C]( createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 8b72bcf20653d..5ce5b7aae6e69 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import scala.util.control.NonFatal + import org.apache.spark.streaming.StreamingContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.NextIterator @@ -74,13 +76,17 @@ class SocketReceiver[T: ClassTag]( while(!isStopped && iterator.hasNext) { store(iterator.next) } - logInfo("Stopped receiving") - restart("Retrying connecting to " + host + ":" + port) + if (!isStopped()) { + restart("Socket data stream had no more data") + } else { + logInfo("Stopped receiving") + } } catch { case e: java.net.ConnectException => restart("Error connecting to " + host + ":" + port, e) - case t: Throwable => - restart("Error receiving data", t) + case NonFatal(e) => + logWarning("Error receiving data", e) + restart("Error receiving data", e) } finally { if (socket != null) { socket.close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index de8718d0a80fe..621d6dff788f4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -51,7 +51,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { val i = iterator.map(t => { val itr = t._2._2.iterator - val headOption = if(itr.hasNext) Some(itr.next) else None + val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) }) updateFuncLocal(i) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 899865a906c27..4efba039f8959 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -44,7 +44,7 @@ class WindowedDStream[T: ClassTag]( // Persist parent level by default, as those RDDs are going to be obviously reused. 
parent.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(parent) @@ -68,7 +68,7 @@ class WindowedDStream[T: ClassTag]( new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow) } else { logDebug("Using normal union for windowing at " + validTime) - new UnionRDD(ssc.sc,rddsInWindow) + new UnionRDD(ssc.sc, rddsInWindow) } Some(windowRDD) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 4bebcc5aa7ca0..8d73593ab6375 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -164,7 +164,7 @@ private[streaming] class BlockGenerator( private def keepPushingBlocks() { logInfo("Started block pushing thread") try { - while(!stopped) { + while (!stopped) { Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => @@ -191,7 +191,7 @@ private[streaming] class BlockGenerator( logError(message, t) listener.onError(message, t) } - + private def pushBlock(block: Block) { listener.onPushBlock(block.id, block.buffer) logInfo("Pushed block " + block.id) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 97db9ded83367..8df542b367d27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.receiver +import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} + import org.apache.spark.{Logging, SparkConf} -import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. 
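In the SocketInputDStream hunk above, the receive loop no longer restarts unconditionally or catches Throwable: when the stream simply runs dry it restarts only if the receiver was not deliberately stopped, and only NonFatal errors trigger a log-and-restart, so fatal JVM errors still propagate. A condensed sketch of that shape, with the socket, store and restart calls stubbed out as hypothetical parameters:

    import scala.util.control.NonFatal

    object ReceiveLoopSketch {
      // Hypothetical receive loop mirroring the error-handling shape of the hunk above.
      def receiveLoop(
          read: () => Option[String],              // None signals end of stream
          store: String => Unit,
          isStopped: () => Boolean,
          restart: (String, Option[Throwable]) => Unit): Unit = {
        try {
          var next = read()
          while (!isStopped() && next.isDefined) {
            store(next.get)
            next = read()
          }
          if (!isStopped()) {
            restart("Socket data stream had no more data", None)  // source closed under us
          }                                                       // else: stopped on purpose
        } catch {
          case NonFatal(e) =>
            restart("Error receiving data", Some(e))  // recoverable: report and retry
        }
      }
    }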
* diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 651b534ac1900..207d64d9414ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -62,7 +62,7 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI private[streaming] class BlockManagerBasedBlockHandler( blockManager: BlockManager, storageLevel: StorageLevel) extends ReceivedBlockHandler with Logging { - + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 92938379b9c17..8be732b64e3a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -138,8 +138,8 @@ private[streaming] class ReceiverSupervisorImpl( ) { val blockId = blockIdOption.getOrElse(nextBlockId) val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size - case _ => -1 + case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong) + case _ => None } val time = System.currentTimeMillis diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index a72efccf2f994..7c0db8a863c67 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -23,7 +23,9 @@ import org.apache.spark.Logging import org.apache.spark.streaming.{Time, StreamingContext} /** To track the information of input stream at specified batch time. */ -private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) +private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { + require(numRecords >= 0, "numRecords must not be negative") +} /** * This class manages all the input streams as well as their input data statistics. 
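The ReceiverSupervisorImpl and InputInfo hunks above (and the ReceivedBlockInfo change just below) replace the old -1 sentinel for an unknown record count with Option[Long]: Some(n) when the count is known, as for ArrayBufferBlock, and None otherwise, with require guards rejecting negative values. Aggregation then becomes flatMap-and-sum, as in the ReceiverInputDStream hunk earlier. A small illustration of the pattern, using a simplified case class rather than the real streaming types:

    // Simplified stand-in for ReceivedBlockInfo: the per-block record count may be unknown.
    case class BlockInfoSketch(blockId: String, numRecords: Option[Long]) {
      require(numRecords.isEmpty || numRecords.get >= 0, "numRecords must not be negative")
    }

    object BlockInfoSketch {
      def totalRecords(blocks: Seq[BlockInfoSketch]): Long =
        // Unknown counts simply drop out of the sum instead of poisoning it with -1.
        blocks.flatMap(_.numRecords).sum

      def main(args: Array[String]): Unit = {
        val blocks = Seq(
          BlockInfoSketch("block-0", Some(100L)),
          BlockInfoSketch("block-1", None),       // e.g. an iterator-based block
          BlockInfoSketch("block-2", Some(42L)))
        println(totalRecords(blocks))             // prints 142
      }
    }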
The information diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1d1ddaaccf217..4af9b6d3b56ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -126,6 +126,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { eventLoop.post(ErrorReported(msg, e)) } + def isStarted(): Boolean = synchronized { + eventLoop != null + } + private def processEvent(event: JobSchedulerEvent) { try { event match { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala index dc11e84f29965..656ac80df8979 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala @@ -24,11 +24,13 @@ import org.apache.spark.streaming.util.WriteAheadLogRecordHandle /** Information about blocks received by the receiver */ private[streaming] case class ReceivedBlockInfo( streamId: Int, - numRecords: Long, + numRecords: Option[Long], metadataOption: Option[Any], blockStoreResult: ReceivedBlockStoreResult ) { + require(numRecords.isEmpty || numRecords.get >= 0, "numRecords must not be negative") + @volatile private var _isBlockIdValid = true def blockId: StreamBlockId = blockStoreResult.blockId diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f73f7e705ee0d..f1504b09c9873 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -230,7 +230,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false class ReceiverLauncher { @transient val env = ssc.env @volatile @transient private var running = false - @transient val thread = new Thread() { + @transient val thread = new Thread() { override def run() { try { SparkEnv.set(env) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 87ba4f84a9ceb..fe6328b1ce727 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -200,7 +200,7 @@ private[streaming] class FileBasedWriteAheadLog( /** Initialize the log directory or recover existing logs inside the directory */ private def initializeOrRecover(): Unit = synchronized { val logDirectoryPath = new Path(logDirectory) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 
4d968f8bfa7a8..408936653c790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -27,7 +27,7 @@ object RawTextHelper { * Splits lines and counts the words. */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OpenHashMap[String,Long] + val map = new OpenHashMap[String, Long] var i = 0 var j = 0 while (iter.hasNext) { @@ -98,7 +98,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 1) { + for (i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index f269cb74e0c2b..08faeaa58f419 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -255,7 +255,7 @@ class BasicOperationsSuite extends TestSuiteBase { Seq( ) ) val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) + s1.map(x => (x, 1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) } testOperation(inputData1, inputData2, operation, outputData, true) } @@ -427,9 +427,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("updateStateByKey - object lifecycle") { val inputData = Seq( - Seq("a","b"), + Seq("a", "b"), null, - Seq("a","c","a"), + Seq("a", "c", "a"), Seq("c"), null, null diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 6a1dd6949b204..9b5e4dc819a2b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming import java.io.NotSerializableException -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{HashPartitioner, SparkContext, SparkException} +import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.ReturnStatementInClosureException @@ -29,7 +29,7 @@ import org.apache.spark.util.ReturnStatementInClosureException /** * Test that closures passed to DStream operations are actually cleaned. 
*/ -class DStreamClosureSuite extends FunSuite with BeforeAndAfterAll { +class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private var ssc: StreamingContext = null override def beforeAll(): Unit = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index e3fb2ef130859..8844c9d74b933 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.streaming -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils @@ -27,7 +27,7 @@ import org.apache.spark.streaming.ui.UIUtils /** * Tests whether scope information is passed from DStream operations to RDDs correctly. */ -class DStreamScopeSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll { +class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { private var ssc: StreamingContext = null private val batchDuration: Duration = Seconds(1) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 0122514f9374c..b74d67c63a788 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -418,7 +418,7 @@ class TestServer(portToBind: Int = 0) extends Logging { val servingThread = new Thread() { override def run() { try { - while(true) { + while (true) { logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() if (startLatch.getCount == 1) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 23804237bda80..cca8cedb1d080 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -41,7 +41,11 @@ import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ -class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class ReceivedBlockHandlerSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with Logging { val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b1af8d5eaacfb..be305b5e0dfea 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ @@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite - extends FunSuite with BeforeAndAfter with Matchers with Logging { + extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() val akkaTimeout = 10 seconds @@ -224,7 +224,7 @@ class ReceivedBlockTrackerSuite /** Generate blocks infos using random ids */ def generateBlockInfos(): Seq[ReceivedBlockInfo] = { - List.fill(5)(ReceivedBlockInfo(streamId, 0, None, + List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index f8e8030791df1..819dd2ccfe915 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -25,16 +25,16 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} +import org.scalatest.{Assertions, BeforeAndAfter} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite} -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName @@ -151,6 +151,22 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w assert(StreamingContext.getActive().isEmpty) } + test("start failure should stop internal components") { + ssc = new StreamingContext(conf, batchDuration) + val inputStream = addInputStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + // Require that the start fails because checkpoint directory was not set + intercept[Exception] { + ssc.start() + } + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(ssc.scheduler.isStarted === false) + } + + test("start multiple 
times") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() @@ -732,7 +748,9 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) def onStop() { // Simulate slow receiver by waiting for all records to be produced - while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + while (!SlowTestReceiver.receivedAllRecords) { + Thread.sleep(100) + } // no clean to be done, the receiving thread should stop on it own } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 312cce408cfe7..1dc8960d60528 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -133,8 +133,10 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for(i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) return false + for (i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) { + return false + } } true } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 554cd30223f44..31b1aebf6a8ec 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler._ @@ -204,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) { * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context def framework: String = this.getClass.getSimpleName diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 441bbf95d0153..cbc24aee4fa1e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -28,14 +28,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ - - - /** * Selenium tests for the Spark Web UI. 
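The test hunks above and below consistently swap ScalaTest's FunSuite for a shared SparkFunSuite base class. Its definition is not part of this diff, but the usual motivation for such a base class is to keep cross-cutting test behaviour (for example, logging test boundaries or shared configuration) in one place. A hypothetical sketch of what a minimal base suite of this kind could look like; this is not the actual SparkFunSuite:

    import org.scalatest.{FunSuite, Outcome}

    // Hypothetical minimal base suite: project tests extend this instead of FunSuite
    // directly, so shared behaviour lives in a single class.
    abstract class ProjectFunSuite extends FunSuite {
      protected override def withFixture(test: NoArgTest): Outcome = {
        val name = s"${getClass.getSimpleName}: '${test.name}'"
        println(s"===== TEST START $name =====")
        try test() finally println(s"===== TEST END $name =====")
      }
    }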
*/ class UISeleniumSuite - extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ @@ -197,4 +194,3 @@ class UISeleniumSuite } } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 6859b65c7165f..cb017b798b2a4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,15 +21,15 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite - extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 5478b41845943..2e210397fe7c7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.streaming.{Time, Duration, StreamingContext} -class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { +class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { private var ssc: StreamingContext = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 2a0f45830e03c..c9175d61b1f49 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -64,7 +64,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala index e9ab917ab845c..d3ca2b58f36c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.streaming.ui import java.util.TimeZone import java.util.concurrent.TimeUnit -import org.scalatest.FunSuite import org.scalatest.Matchers -class UIUtilsSuite extends FunSuite with Matchers{ +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite with Matchers{ test("shortTimeUnitString") { assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 9ebf7b484f421..78fc344b00177 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class RateLimitedOutputStreamSuite extends FunSuite { +class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 79098bcf4861c..325ff7c74c39d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -28,15 +28,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { +class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null @@ -359,7 +359,7 @@ object WriteAheadLogSuite { ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - + // Ensure that 500 does not get sorted after 2000, so put a high base value. 
data.foreach { item => manualClock.advance(500) diff --git a/tools/pom.xml b/tools/pom.xml index 1c6f3e83a1819..feffde4c857eb 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 2fd17267ac427..62c6354f1e203 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index 24b2892098059..192c6714b2406 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -25,8 +25,7 @@ public final class PlatformDependent { /** * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us aovid accidental use of deprecated methods or methods that - * aren't present in Java 6. + * this package. This also lets us avoid accidental use of deprecated methods. */ public static final class UNSAFE { diff --git a/yarn/pom.xml b/yarn/pom.xml index 00d219f836708..644def7501dc8 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.hadoop hadoop-yarn-api diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index aaae6f9734a85..77af46c192cc2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -60,8 +60,11 @@ private[yarn] class AMDelegationTokenRenewer( private val hadoopUtil = YarnSparkHadoopUtil.get - private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) - private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val daysToKeepFiles = + sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = + sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -121,7 +124,7 @@ private[yarn] class AMDelegationTokenRenewer( import scala.concurrent.duration._ try { val remoteFs = FileSystem.get(hadoopConf) - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( remoteFs, credentialsPath.getParent, @@ -160,7 +163,7 @@ private[yarn] class AMDelegationTokenRenewer( val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val dst = 
credentialsPath.getParent keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials @@ -186,8 +189,7 @@ private[yarn] class AMDelegationTokenRenewer( } val nextSuffix = lastCredentialsFileSuffix + 1 val tokenPathStr = - sparkConf.get("spark.yarn.credentials.file") + - SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix val tokenPath = new Path(tokenPathStr) val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index af4927b0e4bf7..002d7b6eaf498 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -34,7 +34,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, Spar import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer -import org.apache.spark.scheduler.cluster.YarnSchedulerBackend +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util._ @@ -67,6 +67,7 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + private val allocatorLock = new Object() // Fields used in client mode. private var rpcEnv: RpcEnv = null @@ -220,7 +221,7 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -231,8 +232,14 @@ private[spark] class ApplicationMaster( .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } .getOrElse("") - allocator = client.register(yarnConf, - if (sc != null) sc.getConf else sparkConf, + val _sparkConf = if (sc != null) sc.getConf else sparkConf + val driverUrl = _rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + allocator = client.register(driverUrl, + yarnConf, + _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, @@ -279,7 +286,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -289,7 +296,7 @@ private[spark] class ApplicationMaster( rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter 
thread. reporterThread.join() @@ -353,7 +360,9 @@ private[spark] class ApplicationMaster( } logDebug(s"Number of pending allocations is $numPendingAllocate. " + s"Sleeping for $sleepInterval.") - Thread.sleep(sleepInterval) + allocatorLock.synchronized { + allocatorLock.wait(sleepInterval) + } } catch { case e: InterruptedException => } @@ -540,8 +549,15 @@ private[spark] class ApplicationMaster( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { - case Some(a) => a.requestTotalExecutors(requestedTotal) - case None => logWarning("Container allocator is not ready to request executors yet.") + case Some(a) => + allocatorLock.synchronized { + if (a.requestTotalExecutors(requestedTotal)) { + allocatorLock.notifyAll() + } + } + + case None => + logWarning("Container allocator is not ready to request executors yet.") } context.reply(true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7e023f2d92578..f4d43214b08ca 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -121,24 +121,31 @@ private[spark] class Client( } catch { case e: Throwable => if (appId != null) { - val appStagingDir = getAppStagingDir(appId) - try { - val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) - val stagingDirPath = new Path(appStagingDir) - val fs = FileSystem.get(hadoopConf) - if (!preserveFiles && fs.exists(stagingDirPath)) { - logInfo("Deleting staging directory " + stagingDirPath) - fs.delete(stagingDirPath, true) - } - } catch { - case ioe: IOException => - logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) - } + cleanupStagingDir(appId) } throw e } } + /** + * Cleanup application staging directory. + */ + private def cleanupStagingDir(appId: ApplicationId): Unit = { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } + } + /** * Set up the context for submitting our ApplicationMaster. * This uses the YarnClientApplication not available in the Yarn alpha API. 
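In the ApplicationMaster hunks above, the Thread.sleep in the reporter loop becomes allocatorLock.wait(sleepInterval), and the RequestExecutors handler calls allocatorLock.notifyAll() when the executor target changes, so new allocation requests are picked up without waiting out the full sleep interval. A self-contained sketch of that wait/notify pattern, using illustrative names rather than Spark's:

// Minimal sketch (not Spark code): a polling loop that sleeps up to `intervalMs`
// per iteration but can be woken early when new work arrives.
object WaitNotifySketch {
  private val lock = new Object()
  private var pendingRequest = false  // guarded by `lock`

  def pollLoop(intervalMs: Long): Unit = {
    while (true) {  // simplified; a real loop would also check a shutdown flag
      doPeriodicWork()
      lock.synchronized {
        if (!pendingRequest) {
          lock.wait(intervalMs)  // bounded wait instead of an unconditional sleep
        }
        pendingRequest = false
      }
    }
  }

  // Request path: record the new demand and wake the poller immediately.
  def requestMore(): Unit = lock.synchronized {
    pendingRequest = true
    lock.notifyAll()
  }

  private def doPeriodicWork(): Unit = ()
}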
@@ -782,6 +789,7 @@ private[spark] class Client( if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + cleanupStagingDir(appId) return (state, report.getFinalApplicationStatus) } @@ -1142,9 +1150,9 @@ object Client extends Logging { logDebug("HiveMetaStore configured in localmode") } } catch { - case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } - case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } - case e:Exception => { logError("Unexpected Exception " + e) + case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e: Exception => { logError("Unexpected Exception " + e) throw new RuntimeException("Unexpected exception", e) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 5653c9f14dc6d..9c7b1b3988082 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) numExecutors = initialNumExecutors } + principal = Option(principal) + .orElse(sparkConf.getOption("spark.yarn.principal")) + .orNull + keytab = Option(keytab) + .orElse(sparkConf.getOption("spark.yarn.keytab")) + .orNull } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index c592ecfdfce06..3d3a966960e9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging { * Add a resource to the list of distributed cache resources. This list can * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. - * Adds the LocalResource to the localResources HashMap passed in and saves + * Adds the LocalResource to the localResources HashMap passed in and saves * the stats of the resources to they can be sent to the executors and verified. 
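The ClientArguments change above lets --principal and --keytab fall back to the spark.yarn.principal and spark.yarn.keytab configuration entries when the flags are not passed on the command line. A sketch of that precedence rule in isolation; the helper name is hypothetical:

import org.apache.spark.SparkConf

// Command-line value wins if present; otherwise fall back to the SparkConf entry;
// otherwise stay null (matching the existing ClientArguments convention).
def resolveWithConfFallback(cliValue: String, confKey: String, conf: SparkConf): String = {
  Option(cliValue)
    .orElse(conf.getOption(confKey))
    .orNull
}

// e.g. principal = resolveWithConfFallback(principal, "spark.yarn.principal", sparkConf)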
* * @param fs FileSystem * @param conf Configuration * @param destPath path to the resource * @param localResources localResource hashMap to insert the resource into - * @param resourceType LocalResourceType + * @param resourceType LocalResourceType * @param link link presented in the distributed cache to the destination - * @param statCache cache to store the file/directory stats + * @param statCache cache to store the file/directory stats * @param appMasterOnly Whether to only add the resource to the app master */ def addResource( fs: FileSystem, conf: Configuration, - destPath: Path, + destPath: Path, localResources: HashMap[String, LocalResource], resourceType: LocalResourceType, link: String, @@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging { amJarRsrc.setSize(destStatus.getLen()) if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { - distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } else { - distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } } @@ -95,13 +95,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -112,13 +112,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + 
env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -160,7 +160,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def ancestorsHaveExecutePermissions( fs: FileSystem, path: Path, - statCache: Map[URI, FileStatus]): Boolean = { + statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { // the subdirs in the path should have execute permissions for others @@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { val stat = statCache.get(uri) match { case Some(existstat) => existstat - case None => + case None => val newStat = fs.getFileStatus(new Path(uri)) statCache.put(uri, newStat) newStat diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 8a08f561a2df2..940873fbd046c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -34,10 +34,8 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.AkkaUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -53,6 +51,7 @@ import org.apache.spark.util.AkkaUtils * synchronized. */ private[yarn] class YarnAllocator( + driverUrl: String, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -107,13 +106,6 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - private val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(securityMgr.akkaSSLOptions.enabled), - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) @@ -154,11 +146,16 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. + * + * @return Whether the new requested total is different than the old value. 
*/ - def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { + def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal + true + } else { + false } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index ffe71dfd7d257..7f533ee55e8bb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,6 +55,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * @param uiHistoryAddress Address of the application on the History Server. */ def register( + driverUrl: String, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -72,7 +73,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5e6531895c7ba..68d01c17ef720 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -144,9 +144,9 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } object YarnSparkHadoopUtil { - // Additional memory overhead + // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering - // the common cases. Memory overhead tends to grow with container size. + // the common cases. Memory overhead tends to grow with container size. 
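With the change above, requestTotalExecutors reports whether the requested total actually changed, which is what lets the ApplicationMaster call notifyAll() only when there is new work for the allocation loop. A small sketch of that contract, with an illustrative class name rather than Spark's:

// Sketch of a synchronized setter that reports whether state changed, so the
// caller can skip spurious wake-ups of the waiting loop.
class TargetTracker {
  private var target = 0

  def setTarget(requested: Int): Boolean = synchronized {
    if (requested != target) {
      target = requested
      true   // caller should notify the polling thread
    } else {
      false  // no change; nothing to wake up
    }
  }
}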
val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 80b57d1355a3a..804dfecde7867 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.mockito.Mockito.when @@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} import scala.collection.mutable.HashMap import scala.collection.mutable.Map +import org.apache.spark.SparkFunSuite -class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + +class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { - override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { LocalResourceVisibility.PRIVATE } } - + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] @@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] val uri = new URI("/tmp/testing") - val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) @@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) // add another one and verify both there and order correct - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", statCache, false) val resource2 = localResources("link2") assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val env2 = new 
HashMap[String, String]() distMgr.setDistFilesEnv(env2) val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") @@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) intercept[Exception] { - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache, false) } assert(localResources.get("link") === None) @@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, true) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 508819e242a26..01d33c9ce9297 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { override def beforeAll(): 
Unit = { System.setProperty("SPARK_YARN_MODE", "true") @@ -203,7 +203,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 455f1019d86dd..7509000771d94 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.spark.SecurityManager +import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} class MockResolver extends DNSToSwitchMapping { @@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping { def reloadCachedMappings(names: JList[String]) {} } -class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new Configuration() conf.setClass( CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, @@ -90,6 +90,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach "--jar", "somejar.jar", "--class", "SomeClass") new YarnAllocator( + "not used", conf, sparkConf, rmClient, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index dcaeb2e43ff41..93d587d0cb36a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -18,21 +18,21 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.net.URL import java.util.Properties import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable -import scala.io.Source import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} +import org.apache.spark._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. 
*/ -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { +class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the YARN containers, so that their output is collected // by YARN instead of trying to overwrite unit-tests.log. @@ -326,7 +326,7 @@ private object YarnClusterDriver extends Logging with Matchers { var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) data should be (Set(1, 2, 3, 4)) result = "success" } finally { @@ -344,18 +344,20 @@ private object YarnClusterDriver extends Logging with Matchers { assert(info.logUrlMap.nonEmpty) } - // If we are running in yarn-cluster mode, verify that driver logs are downloadable. + // If we are running in yarn-cluster mode, verify that driver log links are present and are + // in the expected format. if (conf.get("spark.master") == "yarn-cluster") { assert(listener.driverLogs.nonEmpty) val driverLogs = listener.driverLogs.get assert(driverLogs.size === 2) assert(driverLogs.containsKey("stderr")) assert(driverLogs.containsKey("stdout")) - val stderr = driverLogs("stderr") // YARN puts everything in stderr. - val lines = Source.fromURL(stderr).getLines() - // Look for a line that contains YarnClusterSchedulerBackend, since that is guaranteed in - // cluster mode. - assert(lines.exists(_.contains("YarnClusterSchedulerBackend"))) + val urlStr = driverLogs("stderr") + // Ensure that this is a valid URL, else this will throw an exception + new URL(urlStr) + val containerId = YarnSparkHadoopUtil.get.getContainerId + val user = Utils.getCurrentUserName() + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0")) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index e10b985c3c236..49bee0866dd43 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { val hasBash = try {
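The YarnClusterSuite change above no longer downloads the driver log over HTTP; it only checks that the reported link parses as a URL and ends with the NodeManager containerlogs path. A hedged sketch of that check as a standalone helper; the function name and parameters are illustrative, not Spark API:

import java.net.URL

// Validate the shape of a driver log link without fetching it: parsing throws
// MalformedURLException for bad links, and the suffix check pins the expected
// /node/containerlogs/<containerId>/<user>/stderr?start=0 format.
def checkDriverLogLink(urlStr: String, containerId: String, user: String): Unit = {
  new URL(urlStr)
  require(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0"),
    s"Unexpected driver log link format: $urlStr")
}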