diff --git a/.rat-excludes b/.rat-excludes index 0240e81c45ea2..236c2db05367c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -91,3 +91,5 @@ help/* html/* INDEX .lintr +gen-java.* +.*avpr diff --git a/LICENSE b/LICENSE index 8672be55eca3e..f9e412cade345 100644 --- a/LICENSE +++ b/LICENSE @@ -948,6 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/R/README.md b/R/README.md index d7d65b4f0eca5..005f56da1670c 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ``` build/mvn -DskipTests -Psparkr package ``` diff --git a/R/install-dev.bat b/R/install-dev.bat index 008a5c668bc45..f32670b67de96 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0.. 
MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ + +rem Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd %SPARK_HOME%\R\lib +%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR +popd diff --git a/R/install-dev.sh b/R/install-dev.sh index 1edd551f8d243..4972bb9217072 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -34,7 +34,7 @@ LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -pushd $FWDIR +pushd $FWDIR > /dev/null # Generate Rd files if devtools is installed Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' @@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo # Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ -popd +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +cd $LIB_DIR +jar cfM "$LIB_DIR/sparkr.zip" SparkR + +popd > /dev/null diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index efc85bbc4b316..d028821534b1a 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -32,4 +32,3 @@ Collate: 'serialize.R' 'sparkR.R' 'utils.R' - 'zzz.R' diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6feabf4189c2d..60702824acb46 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -169,8 +169,8 @@ setMethod("isLocal", #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 89511141d3ef7..d2d096709245d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), serializedFuncArr, rdd@env$prev_serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } else { @@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), rdd@env$prev_serializedMode, serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9a743a3411533..30978bb50d339 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -86,7 +86,9 @@ infer_type <- function(x) { createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD - schema <- names(data) + if (is.null(schema)) { + schema <- names(data) + } n <- nrow(data) m <- ncol(data) # get rid of factor type diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 79055b7f18558..fad9d71158c51 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -20,7 +20,8 @@ # @rdname aggregateRDD # @seealso reduce # @export -setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) # @rdname cache-methods # @export diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7f902ba8e683e..ebc6ff65e9d0f 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -215,7 +215,6 @@ setMethod("partitionBy", serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, - as.character(.sparkREnv$libname), broadcastArr, 
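The DataFrame.R and SQLContext.R hunks above change user-visible behavior: showDF gains a truncate argument that is passed through to showString, and createDataFrame no longer overwrites a schema that the caller supplies for a local data.frame. A hedged usage sketch (the session objects, column names, and values are illustrative, not part of the patch):

```
# Assumes an existing SparkR session: sc from sparkR.init() and sqlContext from sparkRSQL.init(sc).
localDF <- data.frame(name = c("Ann", "Bob"), age = c(23L, 31L))

# With the SQLContext.R fix, an explicit schema is honored instead of being replaced by names(localDF).
df <- createDataFrame(sqlContext, localDF, schema = c("person", "years"))

# showDF() now forwards truncate to the JVM side; FALSE prints full cell values.
showDF(df, numRows = 10, truncate = FALSE)
```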
callJMethod(jrdd, "classTag")) @@ -560,8 +559,8 @@ setMethod("join", # Left outer join two RDDs # # @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -597,8 +596,8 @@ setMethod("leftOuterJoin", # Right outer join two RDDs # # @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -634,8 +633,8 @@ setMethod("rightOuterJoin", # Full outer join two RDDs # # @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 633b869f91784..172335809dec2 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -17,10 +17,6 @@ .sparkREnv <- new.env() -sparkR.onLoad <- function(libname, pkgname) { - .sparkREnv$libname <- libname -} - # Utility function that returns TRUE if we have an active connection to the # backend and FALSE otherwise connExists <- function(env) { @@ -80,7 +76,6 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes. #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkRLibDir The path where R is installed on the worker nodes. #' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples @@ -101,15 +96,15 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "", sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { - cat("Re-using existing Spark Context. 
Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "512m") + sparkMem <- Sys.getenv("SPARK_MEM", "1024m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows @@ -169,10 +164,6 @@ sparkR.init <- function( sparkHome <- normalizePath(sparkHome) } - if (nchar(sparkRLibDir) != 0) { - .sparkREnv$libname <- sparkRLibDir - } - sparkEnvirMap <- new.env() for (varname in names(sparkEnvir)) { sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] @@ -180,14 +171,16 @@ sparkR.init <- function( sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { - sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } for (varname in names(sparkExecutorEnv)) { sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs # Seconds resolution is good enough for this purpose, so use ints diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 13cec0f712fb4..ea629a64f7158 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -334,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "MEMORY_ONLY_SER_2", "OFF_HEAP")) { match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" storageLevel <- switch(newLevel, - "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), - "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), - "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), - "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), - "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), - "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), - "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), - "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), - "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), - "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), - "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + 
"MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } # Utility function for functions where an argument needs to be integer but we want to allow @@ -545,9 +548,11 @@ mergePartitions <- function(rdd, zip) { lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. if (zip && lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) } if (lengthOfKeys > 1) { diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8fe711b622086..2a8a8213d0849 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,7 +16,7 @@ # .First <- function() { - home <- Sys.getenv("SPARK_HOME") - .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") + .libPaths(c(packageDir, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index 4db7266abc8e2..ccaea18ecab2a 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -82,7 +82,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index a1e354e567be5..3be8c65a6c1a0 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") rdd<- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 8bc693be20c3c..cc1faeabffe30 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -18,8 +18,8 @@ context("include an external JAR in SparkContext") runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") - jarPath <- paste("--jars", - shQuote(file.path(sparkHome, "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar"))) + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, @@ -31,7 +31,7 @@ runScript <- function() { test_that("sparkJars tag in SparkContext", { testOutput <- 
runScript() helloTest <- testOutput[1] - expect_true(helloTest == "Hello, Dave") + expect_equal(helloTest, "Hello, Dave") basicFunction <- testOutput[2] - expect_true(basicFunction == 4L) + expect_equal(basicFunction, "4") }) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db37..2552127cc547f 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 4fe653856756e..b79692873cec3 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { @@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("sortByKey() on pairwise RDDs", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 6a08f894313c4..b0ea38854304e 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -61,7 +61,7 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = 
"integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) - expect_true(class(testStruct) == "structType") + expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() @@ -73,39 +73,39 @@ test_that("infer types", { test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlContext, rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -150,26 +150,26 @@ test_that("convert NAs to null type in DataFrames", { test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) 
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -219,21 +219,21 @@ test_that("create DataFrame with different data types", { test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) + expect_equal(count(rdd), 3) df <- jsonRDD(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) df <- jsonRDD(sqlContext, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { @@ -248,9 +248,9 @@ test_that("test cache, uncache and clearCache", { test_that("test tableNames and tables", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlContext)) == 1) + expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) - expect_true(count(df) == 1) + expect_equal(count(df), 1) dropTempTable(sqlContext, "table1") }) @@ -258,8 +258,8 @@ test_that("registerTempTable() results in a queryable table and sql() results in df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) dropTempTable(sqlContext, "table1") }) @@ -279,14 +279,14 @@ test_that("insertInto() on a registered table", { registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlContext, "select * from table1")) == 5) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlContext, "select * from table1")) == 2) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") }) @@ -294,16 +294,16 @@ test_that("table() returns a new DataFrame", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { @@ -311,9 +311,9 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - 
expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -333,16 +333,16 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { @@ -352,7 +352,7 @@ test_that("objectFile() works with row serialization", { saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -363,35 +363,35 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { df <- jsonFile(sqlContext, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) -test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { +test_that("multiple pipeline transformations result in an RDD with the correct values", { df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -401,9 +401,9 @@ test_that("multiple pipeline transformations 
starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - expect_true(inherits(second, "RDD")) - expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) @@ -430,36 +430,36 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) }) test_that("distinct() on DataFrames", { @@ -472,15 +472,15 @@ test_that("distinct() on DataFrames", { df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) }) test_that("sample on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) + expect_is(sampled, "DataFrame") sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) @@ -491,15 +491,15 @@ test_that("sample on a DataFrame", { test_that("select operators", { df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age 
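Most of the test churn in this patch replaces expect_true(a == b) with expect_equal(a, b) and expect_true(inherits(x, cls)) with expect_is(x, cls). A minimal sketch of the motivation, assuming only the testthat package these suites already load: the dedicated expectations report the expected and actual values on failure, whereas expect_true() only reports that the condition was FALSE.

```
library(testthat)

cols <- c("age", "name")
expect_equal(length(cols), 2)              # preferred over expect_true(length(cols) == 2)
expect_is(cols, "character")               # preferred over expect_true(inherits(cols, "character"))
expect_identical(cols, c("age", "name"))   # exact comparison, as in the dropna()/fillna() tests
```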
@@ -518,50 +518,50 @@ test_that("select operators", { test_that("select with column", { df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) }) test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) }) test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) # Check if we can apply a user defined schema schema <- structType(structField("name", type = "string"), structField("age", type = "double")) df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df1, "DataFrame")) + expect_is(df1, "DataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Run the same with loadDF df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) @@ -569,8 +569,8 @@ test_that("write.df() as parquet file", { df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { @@ -580,17 +580,17 @@ test_that("test HiveContext", { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -643,65 +643,65 @@ test_that("string operators", { test_that("group by", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == 
count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, "name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) # test suites for %in% filtered3 <- filter(df, "age in (19)") @@ -727,36 +727,43 @@ test_that("join() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 
4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) test_that("showDF()", { df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) - expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) }) test_that("isLocal()", { @@ -775,50 +782,50 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("mutate() and rename()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("write.df() on DataFrame and works with parquetFile", { df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath) - expect_true(inherits(parquetDF, 
"DataFrame")) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) @@ -828,8 +835,8 @@ test_that("parquetFile works with multiple input paths", { parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df)*2) }) test_that("describe() on a DataFrame", { @@ -851,58 +858,58 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) row.names(expected) <- row.names(actual) # identical on two dataframes does not work here. Don't know why. # use identical on all columns as a workaround. - expect_true(identical(expected$age, actual$age)) - expect_true(identical(expected$height, actual$height)) - expect_true(identical(expected$name, actual$name)) + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with threshold expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { @@ -915,22 +922,22 @@ test_that("fillna() on a DataFrame", { expected$age[is.na(expected$age)] <- 50 expected$height[is.na(expected$height)] <- 50.6 actual <- collect(fillna(df, 50.6)) - expect_true(identical(expected, actual)) + 
expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$age[is.na(expected$age)] <- 50 actual <- collect(fillna(df, 50.6, "age")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown", c("age", "name"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # fill with named list @@ -939,7 +946,7 @@ test_that("fillna() on a DataFrame", { expected$height[is.na(expected$height)] <- 50.6 expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index c5eb417b40159..c2c724cdc762f 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,8 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 092ad9dc10c2e..58318dfef71ab 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 15030e6f1d77e..aa0d2a66b9082 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -45,10 +45,10 @@ test_that("serializeToBytes on RDD", { writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") + expect_equal(getSerializedMode(ser.rdd), "byte") unlink(fileName) }) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 43c4288912b18..192d3ae091134 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -22,7 +22,7 @@ # - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). 
# - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. diff --git a/core/pom.xml b/core/pom.xml index 565437c4861a4..558cc3fb9f2f3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -69,16 +69,6 @@ org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark @@ -353,28 +343,28 @@ test - org.mockito - mockito-core + org.hamcrest + hamcrest-core test - org.scalacheck - scalacheck_${scala.binary.version} + org.hamcrest + hamcrest-library test - junit - junit + org.mockito + mockito-core test - org.hamcrest - hamcrest-core + org.scalacheck + scalacheck_${scala.binary.version} test - org.hamcrest - hamcrest-library + junit + junit test diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java similarity index 91% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java rename to core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 3f746b886bc9b..0399abc63c235 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.serializer; import java.io.IOException; import java.io.InputStream; @@ -24,9 +24,7 @@ import scala.reflect.ClassTag; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.PlatformDependent; /** @@ -35,7 +33,8 @@ * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work * around this, we pass a dummy no-op serializer. 
*/ -final class DummySerializerInstance extends SerializerInstance { +@Private +public final class DummySerializerInstance extends SerializerInstance { public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 9e9ed94b7890c..56289573209fb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -30,6 +30,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java new file mode 100644 index 0000000000000..45b78829e4cf7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.annotation.Private; + +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +@Private +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java new file mode 100644 index 0000000000000..438742565c51d --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.base.Charsets; +import com.google.common.primitives.Longs; +import com.google.common.primitives.UnsignedBytes; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.types.UTF8String; + +@Private +public class PrefixComparators { + private PrefixComparators() {} + + public static final StringPrefixComparator STRING = new StringPrefixComparator(); + public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); + public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + + public static final class StringPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + // TODO: can done more efficiently + byte[] a = Longs.toByteArray(aPrefix); + byte[] b = Longs.toByteArray(bPrefix); + for (int i = 0; i < 8; i++) { + int c = UnsignedBytes.compare(a[i], b[i]); + if (c != 0) return c; + } + return 0; + } + + public long computePrefix(byte[] bytes) { + if (bytes == null) { + return 0L; + } else { + byte[] padded = new byte[8]; + System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); + return Longs.fromByteArray(padded); + } + } + + public long computePrefix(String value) { + return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + } + + public long computePrefix(UTF8String value) { + return value == null ? 0L : computePrefix(value.getBytes()); + } + } + + /** + * Prefix comparator for all integral types (boolean, byte, short, int, long). + */ + public static final class IntegralPrefixComparator extends PrefixComparator { + @Override + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public final long NULL_PREFIX = Long.MIN_VALUE; + } + + public static final class FloatPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + float a = Float.intBitsToFloat((int) aPrefix); + float b = Float.intBitsToFloat((int) bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(float value) { + return Float.floatToIntBits(value) & 0xffffffffL; + } + + public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } + + public long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + + public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java new file mode 100644 index 0000000000000..09e4258792204 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java new file mode 100644 index 0000000000000..0c4ebde407cfc --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. 
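The prefix comparators above reduce most record comparisons to a single comparison of two 8-byte prefixes. A minimal usage sketch, assuming the `PrefixComparators` class added in this patch is on the classpath; the values and strings below are only illustrative:

```java
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators;

public class PrefixComparatorDemo {
  public static void main(String[] args) {
    // Doubles: the prefix is the raw bit pattern; compare() decodes it back before comparing.
    long d1 = PrefixComparators.DOUBLE.computePrefix(3.14);
    long d2 = PrefixComparators.DOUBLE.computePrefix(2.71);
    System.out.println(PrefixComparators.DOUBLE.compare(d1, d2)); // positive, since 3.14 > 2.71

    // Strings: only the first 8 bytes fit into the prefix, so equal prefixes do not imply
    // equal strings; a full RecordComparator has to break such ties.
    long s1 = PrefixComparators.STRING.computePrefix("spark-sort");
    long s2 = PrefixComparators.STRING.computePrefix("spark-sorter");
    System.out.println(PrefixComparators.STRING.compare(s1, s2)); // 0, first 8 bytes match
  }
}
```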
+ */ + public long keyPrefix; +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 0000000000000..4d6731ee60af3 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.LinkedList; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * External sorter based on {@link UnsafeInMemorySorter}. + */ +public final class UnsafeExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + private static final int PAGE_SIZE = 1 << 27; // 128 megabytes + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; + + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). 
+ */ + private final LinkedList allocatedPages = new LinkedList(); + + // These variables are reset after spilling: + private UnsafeInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + private final LinkedList spillWriters = new LinkedList<>(); + + public UnsafeExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + SparkConf conf) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + initializeForWriting(); + } + + // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = + new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize); + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + public void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ? 
" times" : " time"); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + sorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + initializeForWriting(); + } + + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + + public long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + // TODO: merge these steps to first calculate total memory requirements for this insert, + // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the + // data page. + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. 
+ if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + + sorter.insertRecord(recordAddress, prefix); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); + int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); + if (spillWriters.isEmpty()) { + return inMemoryIterator; + } else { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + if (inMemoryIterator.hasNext()) { + spillMerger.addSpill(inMemoryIterator); + } + return spillMerger.getSortedIterator(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 0000000000000..fc34ad9cff369 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
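`insertRecord` and `getSortedIterator` form the public surface of `UnsafeExternalSorter`. A hedged usage sketch, assuming a fully constructed `sorter` instance and an on-heap `byte[]` record, whose base offset is `PlatformDependent.BYTE_ARRAY_OFFSET`; the record contents are illustrative:

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;

import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators;
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;

final class ExternalSorterUsageSketch {
  static void insertAndIterate(UnsafeExternalSorter sorter) throws IOException {
    // Records are handed over as (base object, base offset, length) plus an 8-byte key prefix.
    byte[] record = "some row bytes".getBytes(StandardCharsets.UTF_8);
    long prefix = PrefixComparators.STRING.computePrefix(record);
    sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, prefix);

    // The returned iterator is mutable: loadNext() advances it, then the getters describe
    // the current record's location, length, and key prefix.
    UnsafeSorterIterator it = sorter.getSortedIterator();
    while (it.hasNext()) {
      it.loadNext();
      Object base = it.getBaseObject();
      long offset = it.getBaseOffset();
      int length = it.getRecordLength();
      long keyPrefix = it.getKeyPrefix();
    }
  }
}
```

If any spills occurred, the iterator returned here is the merged view over the spill files plus whatever records remain in memory.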
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r1.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length + return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + } else { + return prefixComparisonResult; + } + } + } + + private final TaskMemoryManager memoryManager; + private final Sorter sorter; + private final Comparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] pointerArray; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize * 2]; + this.memoryManager = memoryManager; + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } + + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pointerArrayInsertPosition / 2; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 2 < pointerArray.length; + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? 
(oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a 4-byte integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix a user-defined key prefix + */ + public void insertRecord(long recordPointer, long keyPrefix) { + if (!hasSpaceForAnotherRecord()) { + expandPointerArray(); + } + pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArrayInsertPosition++; + pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArrayInsertPosition++; + } + + private static final class SortedIterator extends UnsafeSorterIterator { + + private final TaskMemoryManager memoryManager; + private final int sortBufferInsertPosition; + private final long[] sortBuffer; + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + + SortedIterator( + TaskMemoryManager memoryManager, + int sortBufferInsertPosition, + long[] sortBuffer) { + this.memoryManager = memoryManager; + this.sortBufferInsertPosition = sortBufferInsertPosition; + this.sortBuffer = sortBuffer; + } + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length + recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 0000000000000..d09c728a7a638 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
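Both `UnsafeInMemorySorter` and `UnsafeSortDataFormat` share the same packed layout: entry `2 * i` of the `long[]` holds the record pointer and entry `2 * i + 1` holds its key prefix, so sorting only ever moves these 16-byte pairs. A self-contained illustration using plain Java arrays; the "pointers" below are made-up IDs rather than real encoded addresses:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class PairLayoutSketch {
  public static void main(String[] args) {
    // data[2*i] = record pointer, data[2*i + 1] = 8-byte key prefix
    long[] data = {
        101L, 30L,   // record 101, prefix 30
        102L, 10L,   // record 102, prefix 10
        103L, 20L    // record 103, prefix 20
    };
    // Reorder only the (pointer, prefix) pairs; the records they point at never move.
    List<long[]> pairs = new ArrayList<long[]>();
    for (int i = 0; i < data.length; i += 2) {
      pairs.add(new long[] { data[i], data[i + 1] });
    }
    Collections.sort(pairs, new Comparator<long[]>() {
      @Override
      public int compare(long[] a, long[] b) {
        return Long.compare(a[1], b[1]);
      }
    });
    for (long[] p : pairs) {
      System.out.println("pointer=" + p[0] + " prefix=" + p[1]); // 102/10, then 103/20, then 101/30
    }
  }
}
```

In the real sorter the pointer is an address encoded by `TaskMemoryManager`, and the record it points at begins with a 4-byte length field, which is why `SortComparator` and `SortedIterator` skip 4 bytes before the record data.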
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * Supports sorting an array of (record pointer, key prefix) pairs. + * Used in {@link UnsafeInMemorySorter}. + *

+ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat extends SortDataFormat { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); + } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java similarity index 65% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala rename to core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java index 8df4f3b554c41..16ac2e8d821ba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -15,17 +15,21 @@ * limitations under the License. 
*/ -package org.apache.spark.scheduler.cluster.mesos +package org.apache.spark.util.collection.unsafe.sort; -import org.apache.spark.SparkContext +import java.io.IOException; -private[spark] object MemoryUtils { - // These defaults copied from YARN - val OVERHEAD_FRACTION = 0.10 - val OVERHEAD_MINIMUM = 384 +public abstract class UnsafeSorterIterator { - def calculateTotalMemory(sc: SparkContext): Int = { - sc.conf.getInt("spark.mesos.executor.memoryOverhead", - math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory - } + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java new file mode 100644 index 0000000000000..8272c2a5be0d1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
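`UnsafeSorterIterator` is the shared contract between the in-memory sorter, the spill readers, and the merger: callers invoke `loadNext()` to advance, then read the current record through the getters, and the same object is reused for every record. A hedged, illustrative implementation over on-heap byte arrays (not part of the patch) that shows the expected call pattern:

```java
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;

// Illustrative only: records and key prefixes are held in parallel arrays.
final class ByteArrayRecordIterator extends UnsafeSorterIterator {
  private final byte[][] records;
  private final long[] prefixes;
  private int next = 0;
  private byte[] current;
  private long currentPrefix;

  ByteArrayRecordIterator(byte[][] records, long[] prefixes) {
    this.records = records;
    this.prefixes = prefixes;
  }

  @Override
  public boolean hasNext() { return next < records.length; }

  @Override
  public void loadNext() {
    // Position on the next record; the getters describe this record until the next call.
    current = records[next];
    currentPrefix = prefixes[next];
    next++;
  }

  @Override
  public Object getBaseObject() { return current; }

  @Override
  public long getBaseOffset() { return PlatformDependent.BYTE_ARRAY_OFFSET; }

  @Override
  public int getRecordLength() { return current.length; }

  @Override
  public long getKeyPrefix() { return currentPrefix; }
}
```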
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.PriorityQueue; + +final class UnsafeSorterSpillMerger { + + private final PriorityQueue priorityQueue; + + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + final int numSpills) { + final Comparator comparator = new Comparator() { + + @Override + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + priorityQueue = new PriorityQueue(numSpills, comparator); + } + + public void addSpill(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + spillReader.loadNext(); + } + priorityQueue.add(spillReader); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return spillReader.getKeyPrefix(); } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 0000000000000..29e9e0f30f934 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
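`UnsafeSorterSpillMerger` above is a textbook k-way merge: a priority queue holds one iterator per sorted run, ordered by each iterator's current head, and the globally smallest head is emitted before its run is re-inserted. A self-contained sketch of the same pattern over sorted integer lists; none of these names are Spark code:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class KWayMergeSketch {
  // One cursor per sorted run; head() is the run's smallest not-yet-consumed value.
  private static final class Cursor {
    final List<Integer> run;
    int pos = 0;
    Cursor(List<Integer> run) { this.run = run; }
    int head() { return run.get(pos); }
    boolean exhausted() { return pos >= run.size(); }
  }

  public static List<Integer> merge(List<List<Integer>> sortedRuns) {
    final Comparator<Cursor> byHead = new Comparator<Cursor>() {
      @Override
      public int compare(Cursor a, Cursor b) {
        return Integer.compare(a.head(), b.head());
      }
    };
    PriorityQueue<Cursor> heap = new PriorityQueue<Cursor>(Math.max(1, sortedRuns.size()), byHead);
    for (List<Integer> run : sortedRuns) {
      if (!run.isEmpty()) {
        heap.add(new Cursor(run));      // each run enters the heap positioned on its first element
      }
    }
    List<Integer> out = new ArrayList<Integer>();
    while (!heap.isEmpty()) {
      Cursor smallest = heap.remove();  // run whose head is globally smallest
      out.add(smallest.head());
      smallest.pos++;
      if (!smallest.exhausted()) {
        heap.add(smallest);             // re-insert under its new head
      }
    }
    return out;
  }

  public static void main(String[] args) {
    List<List<Integer>> runs = Arrays.asList(
        Arrays.asList(1, 4, 7), Arrays.asList(2, 5), Arrays.asList(3, 6, 8));
    System.out.println(merge(runs));    // [1, 2, 3, 4, 5, 6, 7, 8]
  }
}
```

The real merger uses the same prefix-then-record comparison as the in-memory sort, so runs whose current records tie on the 8-byte prefix are still ordered by the full `RecordComparator`.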
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import com.google.common.io.ByteStreams; + +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description + * of the file format). + */ +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + + private InputStream in; + private DataInputStream din; + + // Variables that change with every record read: + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; + private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public void loadNext() throws IOException { + recordLength = din.readInt(); + keyPrefix = din.readLong(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + ByteStreams.readFully(in, arr, 0, recordLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + in = null; + din = null; + } + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 0000000000000..b8d66659804ad --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
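The reader above decodes a simple length-prefixed layout: an int record count, then for each record an int length, a long key prefix, and the record bytes. A hedged sketch of reading that layout with plain `java.io` streams, assuming the data has already been decompressed (the real reader wraps the file through `blockManager.wrapForCompression`):

```java
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

public class SpillFileLayoutSketch {
  public static void readAll(File spillFile) throws IOException {
    DataInputStream in =
        new DataInputStream(new BufferedInputStream(new FileInputStream(spillFile)));
    try {
      int numRecords = in.readInt();        // [# of records (int)]
      for (int i = 0; i < numRecords; i++) {
        int length = in.readInt();          // [len (int)]
        long keyPrefix = in.readLong();     // [prefix (long)]
        byte[] record = new byte[length];   // [data (bytes)]
        in.readFully(record);
        // hand (keyPrefix, record) to the consumer here
      }
    } finally {
      in.close();
    }
  }
}
```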
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Tuple2; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Spills a list of sorted records to disk. Spill files have the following format: + * + * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] + */ +final class UnsafeSorterSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private BlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeLong. + private void writeLongToBuffer(long v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 56); + writeBuffer[offset + 1] = (byte)(v >>> 48); + writeBuffer[offset + 2] = (byte)(v >>> 40); + writeBuffer[offset + 3] = (byte)(v >>> 32); + writeBuffer[offset + 4] = (byte)(v >>> 24); + writeBuffer[offset + 5] = (byte)(v >>> 16); + writeBuffer[offset + 6] = (byte)(v >>> 8); + writeBuffer[offset + 7] = (byte)(v >>> 0); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. 
+ * @param keyPrefix a sort key prefix + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + writeIntToBuffer(recordLength, 0); + writeLongToBuffer(keyPrefix, 4); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + writer.recordWritten(); + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 49329423dca76..0c50b4002cf7b 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -102,7 +102,7 @@ private[spark] class ExecutorAllocationManager( "spark.dynamicAllocation.executorIdleTimeout", "60s") private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s") + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 6909015ff66e6..221b1dab43278 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A heartbeat from executors to the driver. 
This is a shared message used by several internal @@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet private[spark] case object ExpireDeadHosts +private case class ExecutorRegistered(executorId: String) + +private case class ExecutorRemoved(executorId: String) + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext) - extends ThreadSafeRpcEndpoint with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) + extends ThreadSafeRpcEndpoint with SparkListener with Logging { + + def this(sc: SparkContext) { + this(sc, new SystemClock) + } + + sc.addSparkListener(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv @@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def onStart(): Unit = { timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ExpireDeadHosts)) + Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) } }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - override def receive: PartialFunction[Any, Unit] = { - case ExpireDeadHosts => - expireDeadHosts() + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + + // Messages sent and received locally + case ExecutorRegistered(executorId) => + executorLastSeen(executorId) = clock.getTimeMillis() + context.reply(true) + case ExecutorRemoved(executorId) => + executorLastSeen.remove(executorId) + context.reply(true) case TaskSchedulerIsSet => scheduler = sc.taskScheduler - } + context.reply(true) + case ExpireDeadHosts => + expireDeadHosts() + context.reply(true) - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // Messages received from executors case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - executorLastSeen(executorId) = System.currentTimeMillis() - eventLoopThread.submit(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - context.reply(response) - } - }) + if (executorLastSeen.contains(executorId)) { + executorLastSeen(executorId) = clock.getTimeMillis() + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) + } else { + // This may happen if we get an executor's in-flight heartbeat immediately + // after we just removed it. It's not really an error condition so we should + // not log warning here. Otherwise there may be a lot of noise especially if + // we explicitly remove executors (SPARK-4134). + logDebug(s"Received heartbeat from unknown executor $executorId") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. 
However, if it really happens, log it and ask the executor to @@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } + /** + * If the heartbeat receiver is not stopped, notify it of executor registrations. + */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + } + + /** + * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't + * log superfluous errors. + * + * Note that we must do this after the executor is actually removed to guard against the + * following race condition: if we remove an executor's metadata from our data structure + * prematurely, we may get an in-flight heartbeat from the executor before the executor is + * actually removed, in which case we will still mark the executor as a dead host later + * and expire it with loud error messages. + */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + } + private def expireDeadHosts(): Unit = { logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") - val now = System.currentTimeMillis() + val now = clock.getTimeMillis() for ((executorId, lastSeenMs) <- executorLastSeen) { if (now - lastSeenMs > executorTimeoutMs) { logWarning(s"Removing executor $executorId with no recent heartbeats: " + diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 7fcb7830e7b0b..87ab099267b2f 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,6 +121,7 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { + // scalastyle:off println if (Utils.isInInterpreter) { val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { @@ -141,6 +142,7 @@ trait Logging { System.err.println(s"Spark was unable to load $defaultLogProps") } } + // scalastyle:on println } } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af0..32df42d57dbd6 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,7 +17,9 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileInputStream} +import java.security.{KeyStore, NoSuchAlgorithmException} +import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +50,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, 
trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +66,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -94,7 +97,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +105,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. */ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c7a7436462083..82704b1ab2189 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. 
+ * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId @@ -490,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) - _heartbeatReceiver.send(TaskSchedulerIsSet) + _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor @@ -524,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _executorAllocationManager = if (dynamicAllocationEnabled) { assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") + "Dynamic allocation of executors is currently only supported in YARN and Mesos mode") Some(new ExecutorAllocationManager(this, listenerBus, _conf)) } else { None @@ -823,7 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -844,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -870,9 +878,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( @@ -1354,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Return whether dynamically adjusting the amount of resources allocated to - * this application is supported. This is currently only available for YARN. + * this application is supported. This is currently only available for YARN + * and Mesos coarse-grained mode. 
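Since dynamic allocation is now accepted on Mesos coarse-grained clusters as well as YARN, a hedged configuration sketch may help. The property names are standard Spark settings, but the executor bounds and timeout values below are purely illustrative, and an external shuffle service is still expected to be enabled so shuffle data survives executor removal:

```java
import org.apache.spark.SparkConf;

public class DynamicAllocationConfSketch {
  public static SparkConf build() {
    return new SparkConf()
        .setAppName("dynamic-allocation-example")
        .set("spark.dynamicAllocation.enabled", "true")
        // Shuffle files must outlive individual executors for executors to be removable:
        .set("spark.shuffle.service.enabled", "true")
        .set("spark.dynamicAllocation.minExecutors", "1")
        .set("spark.dynamicAllocation.maxExecutors", "20")
        // With this patch the cached-executor idle timeout defaults to effectively infinite,
        // so set it explicitly if executors holding cached blocks should still be reclaimed:
        .set("spark.dynamicAllocation.cachedExecutorIdleTimeout", "600s");
  }
}
```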
*/ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) + private[spark] def supportDynamicAllocation: Boolean = { + (master.contains("yarn") + || master.contains("mesos") + || _conf.getBoolean("spark.dynamicAllocation.testing", false)) + } /** * :: DeveloperApi :: @@ -1375,7 +1388,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestTotalExecutors(numExecutors) @@ -1393,7 +1406,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1411,7 +1424,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { assert(supportDynamicAllocation, - "Killing executors is currently only supported in YARN mode") + "Killing executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) @@ -1896,6 +1909,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * be a HDFS path if running on a cluster. */ def setCheckpointDir(directory: String) { + + // If we are running on a cluster, log a warning if the directory is local. + // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from + // its own local file system, which is incorrect because the checkpoint files + // are actually on the executor machines. 
+ if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { + logWarning("Checkpoint directory must be non-local " + + "if Spark is running on a cluster: " + directory) + } + checkpointDir = Option(directory).map { dir => val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index b0665570e2681..d18fc599e9890 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,7 +22,6 @@ import java.net.Socket import akka.actor.ActorSystem -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties @@ -77,7 +76,7 @@ class SparkEnv ( val conf: SparkConf) extends Logging { // TODO Remove actorSystem - @deprecated("Actor system is no longer supported as of 1.4") + @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false @@ -90,39 +89,42 @@ class SparkEnv ( private var driverTmpDirToDelete: Option[String] = None private[spark] def stop() { - isStopped = true - pythonWorkers.foreach { case(key, worker) => worker.stop() } - Option(httpFileServer).foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() - - // Note that blockTransferService is stopped by BlockManager since it is started by it. - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. - // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the - // current working dir in executor which we do not need to delete. - driverTmpDirToDelete match { - case Some(path) => { - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + Option(httpFileServer).foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. 
+ driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } @@ -171,7 +173,7 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. */ - @deprecated("Use SparkEnv.get instead", "1.2") + @deprecated("Use SparkEnv.get instead", "1.2.0") def getThreadLocal: SparkEnv = { env } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 1a5f2bca26c2b..b7e72d4d0ed0b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -95,7 +95,9 @@ private[spark] class RBackend { private[spark] object RBackend extends Logging { def main(args: Array[String]): Unit = { if (args.length < 1) { + // scalastyle:off println System.err.println("Usage: RBackend ") + // scalastyle:on println System.exit(-1) } val sparkRBackend = new RBackend() diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 4dfa7325934ff..23a470d6afcae 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { protected var dataStream: DataInputStream = _ @@ -60,7 +59,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( // The stdout/stderr is shared by multiple tasks, because we use one daemon // to launch child process as worker. - val errThread = RRDD.createRWorker(rLibDir, listenPort) + val errThread = RRDD.createRWorker(listenPort) // We use two sockets to separate input and output, then it's easy to manage // the lifecycle of them to avoid deadlock. 
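// The comment just above explains why the R worker handshake uses two sockets. A standalone
// sketch (not from this patch) of that pattern: one socket carries only task input, the other
// only results, so neither side blocks writing while the peer is also blocked writing. The
// in-process "worker" thread here stands in for the external R process.
import java.io.{BufferedReader, InputStreamReader, PrintWriter}
import java.net.{InetAddress, ServerSocket, Socket}

object TwoSocketSketch {
  def main(args: Array[String]): Unit = {
    val host = InetAddress.getByName("localhost")
    val inServer = new ServerSocket(0, 1, host)   // parent -> worker (task input)
    val outServer = new ServerSocket(0, 1, host)  // worker -> parent (results)

    val worker = new Thread {
      override def run(): Unit = {
        val in = new Socket(host, inServer.getLocalPort)
        val out = new Socket(host, outServer.getLocalPort)
        val reader = new BufferedReader(new InputStreamReader(in.getInputStream))
        val writer = new PrintWriter(out.getOutputStream, true)
        var line = reader.readLine()
        while (line != null) {
          writer.println(s"processed: $line")
          line = reader.readLine()
        }
        in.close(); out.close()
      }
    }
    worker.start()

    val inputSock = inServer.accept()   // written to only
    val outputSock = outServer.accept() // read from only
    val taskWriter = new PrintWriter(inputSock.getOutputStream, true)
    val resultReader = new BufferedReader(new InputStreamReader(outputSock.getInputStream))

    taskWriter.println("task-1")
    println(resultReader.readLine())

    taskWriter.close()
    inputSock.close(); outputSock.close()
    inServer.close(); outServer.close()
  }
}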
@@ -161,7 +160,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { // write string(for StringRRDD) + // scalastyle:off println printOut.println(elem) + // scalastyle:on println } } @@ -233,11 +234,10 @@ private class PairwiseRRDD[T: ClassTag]( hashFunc: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, (Int, Array[Byte])]( parent, numPartitions, hashFunc, deserializer, - SerializationFormats.BYTE, packageNames, rLibDir, + SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): (Int, Array[Byte]) = { @@ -264,10 +264,9 @@ private class RRDD[T: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, Array[Byte]]( - parent, -1, func, deserializer, serializer, packageNames, rLibDir, + parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): Array[Byte] = { @@ -291,10 +290,9 @@ private class StringRRDD[T: ClassTag]( func: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, String]( - parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): String = { @@ -390,9 +388,10 @@ private[r] object RRDD { thread } - private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { - val rCommand = "Rscript" + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. @@ -411,7 +410,7 @@ private[r] object RRDD { /** * ProcessBuilder used to launch worker R processes. 
*/ - def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + def createRWorker(port: Int): BufferedStreamThread = { val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) if (!Utils.isWindows && useDaemon) { synchronized { @@ -419,7 +418,7 @@ private[r] object RRDD { // we expect one connections val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + errThread = createRProcess(daemonPort, "daemon.R") // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() @@ -441,7 +440,7 @@ private[r] object RRDD { errThread } } else { - createRProcess(rLibDir, port, "worker.R") + createRProcess(port, "worker.R") } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala new file mode 100644 index 0000000000000..d53abd3408c55 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.File + +import org.apache.spark.{SparkEnv, SparkException} + +private[spark] object RUtils { + /** + * Get the SparkR package path in the local spark distribution. + */ + def localSparkRPackagePath: Option[String] = { + val sparkHome = sys.env.get("SPARK_HOME") + sparkHome.map( + Seq(_, "R", "lib").mkString(File.separator) + ) + } + + /** + * Get the SparkR package path in various deployment modes. + * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` + * and environment variable `SPARK_HOME` are set. + */ + def sparkRPackagePath(isDriver: Boolean): String = { + val (master, deployMode) = + if (isDriver) { + (sys.props("spark.master"), sys.props("spark.submit.deployMode")) + } else { + val sparkConf = SparkEnv.get.conf + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + } + + val isYarnCluster = master.contains("yarn") && deployMode == "cluster" + val isYarnClient = master.contains("yarn") && deployMode == "client" + + // In YARN mode, the SparkR package is distributed as an archive symbolically + // linked to the "sparkr" file in the current directory. Note that this does not apply + // to the driver in client mode because it is run outside of the cluster. + if (isYarnCluster || (isYarnClient && !isDriver)) { + new File("sparkr").getAbsolutePath + } else { + // Otherwise, assume the package is local + // TODO: support this for Mesos + localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. 
Can't locate SparkR package.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 848b62f9de71b..f03875a3e8c89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -18,17 +18,17 @@ package org.apache.spark.deploy import scala.collection.mutable.HashSet -import scala.concurrent._ +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import scala.util.{Failure, Success} -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} /** * Proxy that relays messages to the driver. @@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} * We currently don't support retry if submission fails. In HA mode, client will submit request to * all masters and see which one could handle it. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - private val masterActors = driverArgs.masters.map { m => - context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) - } - private val lostMasters = new HashSet[Address] - private var activeMasterActor: ActorSelection = null - - val timeout = RpcUtils.askTimeout(conf) - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - +private class ClientEndpoint( + override val rpcEnv: RpcEnv, + driverArgs: ClientArguments, + masterEndpoints: Seq[RpcEndpointRef], + conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging { + + // A scheduled executor used to send messages at the specified time. + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message") + // Used to provide the implicit parameter of `Future` methods. + private val forwardMessageExecutionContext = + ExecutionContext.fromExecutor(forwardMessageThread, + t => t match { + case ie: InterruptedException => // Exit normally + case e: Throwable => + logError(e.getMessage, e) + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + }) + + private val lostMasters = new HashSet[RpcAddress] + private var activeMasterEndpoint: RpcEndpointRef = null + + override def onStart(): Unit = { driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -82,44 +92,52 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.cores, driverArgs.supervise, command) - - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestSubmitDriver(driverDescription) - } + ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! 
RequestKillDriver(driverId) - } + ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + } + } + + /** + * Send the message to master and forward the reply to self asynchronously. + */ + private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + for (masterEndpoint <- masterEndpoints) { + masterEndpoint.ask[T](message).onComplete { + case Success(v) => self.send(v) + case Failure(e) => + logWarning(s"Error sending messages to master $masterEndpoint", e) + }(forwardMessageExecutionContext) } } /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println("... waiting before polling master for driver state") + // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread + // is fine. + logInfo("... waiting before polling master for driver state") Thread.sleep(5000) - println("... polling master for driver state") - val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + logInfo("... polling master for driver state") + val statusResponse = + activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => - println(s"ERROR: Cluster master did not recognize $driverId") + logError(s"ERROR: Cluster master did not recognize $driverId") System.exit(-1) case true => - println(s"State of $driverId is ${statusResponse.state.get}") + logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - println(s"Driver running on $hostPort ($id)") + logInfo(s"Driver running on $hostPort ($id)") case _ => } // Exception, if present statusResponse.exception.map { e => - println(s"Exception from cluster was: $e") + logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) } @@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { - case SubmitDriverResponse(success, driverId, message) => - println(message) + case SubmitDriverResponse(master, success, driverId, message) => + logInfo(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId.get) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } - case KillDriverResponse(driverId, success, message) => - println(message) + case KillDriverResponse(master, driverId, success, message) => + logInfo(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } + } - case DisassociatedEvent(_, remoteAddress, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") - lostMasters += remoteAddress - // Note that this heuristic does not account for the fact that a Master can recover within - // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This - // is not currently a concern, however, because this client does not retry submissions. 
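// The ayncSendToMasterAndForwardReply helper above asks every master and forwards whichever
// reply arrives back to the endpoint's own inbox. A standalone sketch (not from this patch) of
// that fan-out-and-forward shape using plain scala.concurrent.Future; `Reply`, `ask` and
// `sendToSelf` are illustrative stand-ins for Spark's RpcEndpointRef.ask and self.send.
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

object FanOutToMastersSketch {
  final case class Reply(from: String, body: String)

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(2)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

    val masters = Seq("spark://master1:7077", "spark://master2:7077")

    // Stand-in for masterEndpoint.ask[T](message); in practice only the active master answers.
    def ask(master: String, message: String): Future[Reply] =
      Future(Reply(master, s"ack($message)"))

    // Stand-in for self.send(v): re-enqueue the reply on the endpoint's own inbox.
    def sendToSelf(reply: Reply): Unit = println(s"forwarded to self: $reply")

    masters.foreach { m =>
      ask(m, "RequestSubmitDriver(...)").onComplete {
        case Success(v) => sendToSelf(v)
        case Failure(e) => println(s"Error sending message to master $m: $e")
      }
    }

    Thread.sleep(500) // crude wait so the demo futures finish before the pool is shut down
    pool.shutdown()
  }
}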
- if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. + if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) } + } + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") - lostMasters += remoteAddress - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master ($remoteAddress).") + logError(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) } + } + } + + override def onError(cause: Throwable): Unit = { + logError(s"Error processing messages, exiting.") + cause.printStackTrace() + System.exit(-1) + } + + override def onStop(): Unit = { + forwardMessageThread.shutdownNow() } } @@ -179,10 +209,12 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) */ object Client { def main(args: Array[String]) { + // scalastyle:off println if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a future version of Spark") println("Use ./bin/spark-submit with \"--master spark://host:port\"") } + // scalastyle:on println val conf = new SparkConf() val driverArgs = new ClientArguments(args) @@ -194,15 +226,13 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val rpcEnv = + RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - for (m <- driverArgs.masters) { - Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) - } - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). 
+ map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 316e2d59f01b8..72cc330a398da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) { cmd = "launch" if (!ClientArguments.isValidJarUrl(_jarUrl)) { + // scalastyle:off println println(s"Jar url '${_jarUrl}' is not in valid format.") println(s"Must be a jar file path in URL format " + "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)") + // scalastyle:on println printUsageAndExit(-1) } @@ -110,14 +112,16 @@ private[deploy] class ClientArguments(args: Array[String]) { | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin + // scalastyle:off println System.err.println(usage) + // scalastyle:on println System.exit(exitCode) } } private[deploy] object ClientArguments { val DEFAULT_CORES = 1 - val DEFAULT_MEMORY = 512 // MB + val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB val DEFAULT_SUPERVISE = false def isValidJarUrl(s: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 9db6fd1ac4dbe..12727de9b4cf3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable -/** Contains messages sent between Scheduler actor nodes. */ +/** Contains messages sent between Scheduler endpoint nodes. 
*/ private[deploy] object DeployMessages { // Worker to Master @@ -37,6 +38,7 @@ private[deploy] object DeployMessages { id: String, host: String, port: Int, + worker: RpcEndpointRef, cores: Int, memory: Int, webUiPort: Int, @@ -63,11 +65,11 @@ private[deploy] object DeployMessages { case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], driverIds: Seq[String]) - case class Heartbeat(workerId: String) extends DeployMessage + case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage @@ -92,13 +94,13 @@ private[deploy] object DeployMessages { // Worker internal - case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders case object ReregisterWithMaster // used when a worker attempts to reconnect to a master // AppClient to Master - case class RegisterApplication(appDescription: ApplicationDescription) + case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) extends DeployMessage case class UnregisterApplication(appId: String) @@ -107,7 +109,7 @@ private[deploy] object DeployMessages { // Master to AppClient - case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage + case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -123,12 +125,14 @@ private[deploy] object DeployMessages { case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage - case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + case class SubmitDriverResponse( + master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String) extends DeployMessage case class RequestKillDriver(driverId: String) extends DeployMessage - case class KillDriverResponse(driverId: String, success: Boolean, message: String) + case class KillDriverResponse( + master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage case class RequestDriverStatus(driverId: String) extends DeployMessage @@ -142,7 +146,7 @@ private[deploy] object DeployMessages { // Master to Worker & AppClient - case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String) // MasterWebUI To Master diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 2954f932b4f41..ccffb36652988 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -76,12 +76,13 @@ private[deploy] object JsonProtocol { } def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - 
("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 0550f00a172ab..53356addf6edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master @@ -41,8 +40,8 @@ class LocalSparkCluster( extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterRpcEnvs = ArrayBuffer[RpcEnv]() + private val workerRpcEnvs = ArrayBuffer[RpcEnv]() // exposed for testing var masterWebUIPort = -1 @@ -55,18 +54,17 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, webUiPort, _) = - Master.startSystemAndActor(localHostname, 0, 0, _conf) + val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) masterWebUIPort = webUiPort - masterActorSystems += masterSystem - val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort + masterRpcEnvs += rpcEnv + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum), _conf) - workerActorSystems += workerSystem + workerRpcEnvs += workerEnv } masters @@ -77,11 +75,11 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. 
- workerActorSystems.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.shutdown()) // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) + masterRpcEnvs.foreach(_.shutdown()) // masterActorSystems.foreach(_.awaitTermination()) - masterActorSystems.clear() - workerActorSystems.clear() + masterRpcEnvs.clear() + workerRpcEnvs.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index e99779f299785..c0cab22fa8252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path -import org.apache.spark.api.r.RBackend +import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.util.RedirectThread /** @@ -71,9 +71,10 @@ object RRunner { val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) - val sparkHome = System.getenv("SPARK_HOME") + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir) env.put("R_PROFILE_USER", - Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() @@ -85,7 +86,9 @@ object RRunner { } System.exit(returnCode) } else { + // scalastyle:off println System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 7fa75ac8c2b54..6d14590a1d192 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -334,6 +334,19 @@ class SparkHadoopUtil extends Logging { * Stop the thread that does the delegation token updates. */ private[spark] def stopExecutorDelegationTokenRenewer() {} + + /** + * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. + * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. 
+ */ + private[spark] def getConfBypassingFSCache( + hadoopConf: Configuration, + scheme: String): Configuration = { + val newConf = new Configuration(hadoopConf) + val confKey = s"fs.${scheme}.impl.disable.cache" + newConf.setBoolean(confKey, true) + newConf + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index abf222757a95b..7089a7e26707f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -37,6 +37,7 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} +import org.apache.spark.api.r.RUtils import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -79,9 +80,11 @@ object SparkSubmit { private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // scalastyle:off println // Exposed for testing private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err @@ -102,11 +105,14 @@ object SparkSubmit { printStream.println("Type --help for more information.") exitFn(0) } + // scalastyle:on println def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { + // scalastyle:off println printStream.println(appArgs) + // scalastyle:on println } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs) @@ -160,7 +166,9 @@ object SparkSubmit { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { + // scalastyle:off println printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println exitFn(1) } else { throw e @@ -178,7 +186,9 @@ object SparkSubmit { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { + // scalastyle:off println printStream.println("Running Spark using the REST application submission protocol.") + // scalastyle:on println doRunMain() } catch { // Fail over to use the legacy submission gateway @@ -254,6 +264,12 @@ object SparkSubmit { } } + // Update args.deployMode if it is null. It will be passed down as a Spark property later. 
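// The getConfBypassingFSCache helper added above copies the Hadoop conf and sets
// fs.<scheme>.impl.disable.cache so that FileSystem.get returns a fresh client instead of a
// cached one holding an old delegation token. A standalone sketch (not from this patch) of the
// same key in use; the "file" scheme is chosen only so the example runs without an HDFS cluster.
import java.net.URI
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem

object FsCacheBypassSketch {
  def confBypassingFsCache(base: Configuration, scheme: String): Configuration = {
    val fresh = new Configuration(base)
    fresh.setBoolean(s"fs.$scheme.impl.disable.cache", true)
    fresh
  }

  def main(args: Array[String]): Unit = {
    val conf = confBypassingFsCache(new Configuration(), "file")
    val fs1 = FileSystem.get(new URI("file:///"), conf)
    val fs2 = FileSystem.get(new URI("file:///"), conf)
    // With the cache disabled these are distinct instances; with it enabled they would be shared.
    println(fs1 eq fs2)
  }
}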
+ (args.deployMode, deployMode) match { + case (null, CLIENT) => args.deployMode = "client" + case (null, CLUSTER) => args.deployMode = "cluster" + case _ => + } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER @@ -339,6 +355,23 @@ object SparkSubmit { } } + // In YARN mode for an R app, add the SparkR package archive to archives + // that can be distributed with the job + if (args.isR && clusterManager == YARN) { + val rPackagePath = RUtils.localSparkRPackagePath + if (rPackagePath.isEmpty) { + printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + } + val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + } + val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + + // Assigns a symbol link name "sparkr" to the shipped package. + args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + } + // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -367,6 +400,8 @@ object SparkSubmit { // All cluster managers OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.submit.deployMode"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), @@ -558,6 +593,7 @@ object SparkSubmit { sysProps: Map[String, String], childMainClass: String, verbose: Boolean): Unit = { + // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -565,6 +601,7 @@ object SparkSubmit { printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } + // scalastyle:on println val loader = if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { @@ -592,8 +629,10 @@ object SparkSubmit { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { + // scalastyle:off println printStream.println(s"Failed to load main class $childMainClass.") printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -756,6 +795,22 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + // scalastyle:off println + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) localM2.setRoot(m2Path.toURI.toString) @@ -786,20 +841,6 @@ private[spark] object SparkSubmitUtils { 
sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -829,7 +870,9 @@ private[spark] object SparkSubmitUtils { val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) dd.addDependencyConfiguration(ivyConfName, ivyConfName) + // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println md.addDependency(dd) } } @@ -896,9 +939,11 @@ private[spark] object SparkSubmitUtils { ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } + // scalastyle:off println printStream.println( s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers @@ -922,6 +967,15 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + md.setDefaultConf(ivyConfName) // Add exclusion rules for Spark and Scala Library diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b7429a901e162..ebb39c354dff1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -79,6 +79,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() + // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => @@ -86,6 +87,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } + // scalastyle:on println defaultProperties } @@ -162,6 +164,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) @@ -451,6 +454,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { + // scalastyle:off println val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) @@ -461,8 +465,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) outStream.println(command) + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB outStream.println( - """ + s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -488,7 +493,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M). + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: ${mem_mb}M). | --driver-java-options Extra Java options to pass to the driver. | --driver-library-path Extra library path entries to pass to the driver. | --driver-class-path Extra class path entries to pass to the driver. 
Note that @@ -539,6 +544,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("CLI options:") outStream.println(getSqlShellOptions()) } + // scalastyle:on println SparkSubmit.exitFn(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 43c8a934c311a..79b251e7e62fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. */ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) + private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) - private val REGISTRATION_TIMEOUT = 20.seconds + private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var masterAddress: Address = null - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null private var appId: String = null - private var registered = false - private var activeMasterUrl: String = null + @volatile private var registered = false + + private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint + with Logging { + + private var master: Option[RpcEndpointRef] = None + // To avoid calling listener.disconnected() multiple times + private var alreadyDisconnected = false + @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times + @volatile private var registerMasterFutures: Array[JFuture[_]] = null + @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. 
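// The comment above sizes the registration pool so every master can be contacted in parallel.
// A standalone sketch (not from this patch) of that executor shape: core size 0, max size equal
// to the number of masters, and a SynchronousQueue so each submitted registration gets its own
// (daemon) thread. `namedDaemonFactory` is an illustrative stand-in for
// ThreadUtils.namedThreadFactory.
import java.util.concurrent.{SynchronousQueue, ThreadFactory, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object RegisterMasterPoolSketch {
  private def namedDaemonFactory(prefix: String): ThreadFactory = new ThreadFactory {
    private val count = new AtomicInteger(0)
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, s"$prefix-${count.incrementAndGet()}")
      t.setDaemon(true)
      t
    }
  }

  def main(args: Array[String]): Unit = {
    val masters = Seq("master1:7077", "master2:7077", "master3:7077")
    val pool = new ThreadPoolExecutor(
      0, masters.size, 60L, TimeUnit.SECONDS,
      new SynchronousQueue[Runnable](),
      namedDaemonFactory("appclient-register-master"))

    masters.foreach { m =>
      pool.submit(new Runnable {
        override def run(): Unit =
          println(s"${Thread.currentThread().getName}: registering with $m")
      })
    }
    pool.shutdown()
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}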
+ private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) - private class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + // A scheduled executor for scheduling the registration actions + private val registrationRetryThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + override def onStart(): Unit = { try { - registerWithMaster() + registerWithMaster(1) } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } - def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterApplication(appDescription) + /** + * Register with all masters asynchronously and returns an array `Future`s for cancellation. + */ + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + for (masterAddress <- masterRpcAddresses) yield { + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = try { + if (registered) { + return + } + logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + val masterRef = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterRef.send(RegisterApplication(appDescription, self)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + }) } } - def registerWithMaster() { - tryRegisterAllMasters() - import context.dispatcher - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + /** + * Register with all masters asynchronously. It will call `registerWithMaster` every + * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times. + * Once we connect to a master successfully, all scheduling work and Futures will be cancelled. + * + * nthRetry means this is the nth attempt to register with master. + */ + private def registerWithMaster(nthRetry: Int) { + registerMasterFutures = tryRegisterAllMasters() + registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { Utils.tryOrExit { - retries += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() + } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - tryRegisterAllMasters() + registerMasterFutures.foreach(_.cancel(true)) + registerWithMaster(nthRetry + 1) } } } - } + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) } - def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. 
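// registerWithMaster(nthRetry) above schedules a periodic check that either cancels the
// outstanding registrations once connected, gives up after REGISTRATION_RETRIES attempts, or
// retries. A simplified standalone sketch (not from this patch) of that bounded retry loop,
// chaining one-shot schedule() calls instead of scheduleAtFixedRate and omitting the per-master
// futures; the timings and the "second attempt succeeds" stub are purely illustrative.
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

object RegistrationRetrySketch {
  private val registered = new AtomicBoolean(false)
  private val scheduler = Executors.newSingleThreadScheduledExecutor()
  private val MaxRetries = 3
  private val RetryIntervalSeconds = 2L

  private def tryRegister(attempt: Int): Unit = {
    println(s"attempt $attempt: sending RegisterApplication to all masters")
    if (attempt == 2) registered.set(true) // pretend the second attempt gets through

    scheduler.schedule(new Runnable {
      override def run(): Unit = {
        if (registered.get()) {
          println("registered; stop retrying")
          scheduler.shutdown()
        } else if (attempt >= MaxRetries) {
          println("all masters are unresponsive; giving up")
          scheduler.shutdown()
        } else {
          tryRegister(attempt + 1)
        }
      }
    }, RetryIntervalSeconds, TimeUnit.SECONDS)
  }

  def main(args: Array[String]): Unit = {
    tryRegister(1)
    scheduler.awaitTermination(30, TimeUnit.SECONDS)
  }
}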
- activeMasterUrl = url - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => logWarning(s"Drop $message because has not yet connected to master") + } } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = { + masterRpcAddresses.contains(remoteAddress) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredApplication(appId_, masterUrl) => + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId_, masterRef) => + // FIXME How to handle the following cases? + // 1. A master receives multiple registrations and sends back multiple + // RegisteredApplications due to an unstable network. + // 2. Receive multiple RegisteredApplication from different masters because the master is + // changing. appId = appId_ registered = true - changeMaster(masterUrl) + master = Some(masterRef) listener.connected(appId) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not + // guaranteed), `ExecutorStateChanged` may be sent to a dead master. + sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -142,24 +184,32 @@ private[spark] class AppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + master = Some(masterRef) alreadyDisconnected = false - sender ! 
MasterChangeAcknowledged(appId) + masterRef.send(MasterChangeAcknowledged(appId)) + } - case DisassociatedEvent(_, address, _) if address == masterAddress => + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopAppClient => + markDead("Application has been stopped.") + sendToMaster(UnregisterApplication(appId)) + context.reply(true) + stop() + } + + override def onDisconnected(address: RpcAddress): Unit = { + if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = { + if (isPossibleMaster(address)) { logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - master ! UnregisterApplication(appId) - sender ! true - context.stop(self) + } } /** @@ -179,28 +229,31 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop(): Unit = { + if (registrationRetryTimer != null) { + registrationRetryTimer.cancel(true) + } + registrationRetryThread.shutdownNow() + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() } } def start() { // Just launch an actor; it will call back into the listener. - actor = actorSystem.actorOf(Props(new ClientActor)) + endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) } def stop() { - if (actor != null) { + if (endpoint != null) { try { - val timeout = RpcUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + endpoint.askWithRetry[Boolean](StopAppClient) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - actor = null + endpoint = null } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 40835b9550586..1c79089303e3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -46,13 +47,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, - conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git 
a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala index c5ac45c6730d3..a98b1fa8f83a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.client private[spark] object TestExecutor { def main(args: Array[String]) { + // scalastyle:off println println("Hello world!") + // scalastyle:on println while (true) { Thread.sleep(1000) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 5427a88f32ffd..2cc465e55fceb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -83,12 +83,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - // Constants used to parse Spark 1.0.0 log directories. - private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -146,7 +140,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { try { applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).map { attempt => + appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => val replayBus = new ReplayListenerBus() val ui = { val conf = this.conf.clone() @@ -155,20 +149,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() replayBus.addListener(appListener) val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - - ui.setAppName(s"${appInfo.name} ($appId)") - - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui + appInfo.map { info => + ui.setAppName(s"${info.name} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, + appListener.viewAcls.getOrElse("")) + ui + } } } } catch { @@ -282,8 +276,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = 
logs.flatMap { fileStatus => try { val res = replay(fileStatus, bus) - logInfo(s"Application log ${res.logPath} loaded successfully.") - Some(res) + res match { + case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") + case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + } + res } catch { case e: Exception => logError( @@ -429,9 +427,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replays the events in the specified log file and returns information about the associated - * application. + * application. Return `None` if the application ID cannot be located. */ - private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = { + private def replay( + eventLog: FileStatus, + bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") val logInput = @@ -445,16 +445,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted) - new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. Some old versions of Spark generate logs without an app ID, so let + // logs generated by those versions go through. + if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + Some(new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(eventLog).get, + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted)) + } else { + None + } } finally { logInput.close() } @@ -529,10 +537,34 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Returns whether the version of Spark that generated logs records app IDs. App IDs were added + * in Spark 1.1. + */ + private def sparkVersionHasAppId(entry: FileStatus): Boolean = { + if (isLegacyLogDirectory(entry)) { + fs.listStatus(entry.getPath()) + .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } + .map { status => + val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) + version != "1.0" && version != "1.1" + } + .getOrElse(true) + } else { + true + } + } + } -private object FsHistoryProvider { +private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + // Constants used to parse Spark 1.0.0 log directories. 
+ val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 4692d22651c93..18265df9faa2c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin Utils.loadDefaultSparkProperties(conf, propertiesFile) private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( """ |Usage: HistoryServer [options] @@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin | spark.history.fs.updateInterval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) + // scalastyle:on println System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 1620e95bea218..aa54ed9360f36 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,9 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +32,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fccceb3ea528b..48070768f6edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,20 +21,18 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -47,23 +45,27 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import 
org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[master] class Master( - host: String, - port: Int, + override val rpcEnv: RpcEnv, + address: RpcAddress, webUiPort: Int, val securityMgr: SecurityManager, val conf: SparkConf) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends ThreadSafeRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + + // TODO Remove it once we don't use akka.serialization.Serialization + private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) @@ -75,10 +77,10 @@ private[master] class Master( val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] - private val addressToWorker = new HashMap[Address, WorkerInfo] + private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] - private val actorToApp = new HashMap[ActorRef, ApplicationInfo] - private val addressToApp = new HashMap[Address, ApplicationInfo] + private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 private val appIdToUI = new HashMap[String, SparkUI] @@ -89,21 +91,22 @@ private[master] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(address.host, "Expected hostname") private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) private val masterSource = new MasterSource(this) - private val webUi = new MasterWebUI(this, webUiPort) + // After onStart, webUi will be set + private var webUi: MasterWebUI = null private val masterPublicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host + if (envVar != null) envVar else address.host } - private val masterUrl = "spark://" + host + ":" + port + private val masterUrl = address.toSparkURL private var masterWebUiUrl: String = _ private var state = RecoveryState.STANDBY @@ -112,7 +115,9 @@ private[master] class Master( private var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ + + private var checkForWorkerTimeOutTask: ScheduledFuture[_] = 
_ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -130,20 +135,23 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) + Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) } else { None } private val restServerBoundPort = restServer.map(_.start()) - override def preStart() { + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new MasterWebUI(this, webUiPort) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CheckForWorkerTimeOut) + } + }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -157,16 +165,16 @@ private[master] class Master( case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(context.system)) + .newInstance(conf, SerializationExtension(actorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -176,18 +184,17 @@ private[master] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! - logError("Master actor restarted due to exception", reason) - } - - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) } + if (checkForWorkerTimeOutTask != null) { + checkForWorkerTimeOutTask.cancel(true) + } + forwardMessageThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -197,14 +204,14 @@ private[master] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! 
RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -215,8 +222,11 @@ private[master] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = forwardMessageThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CompleteRecovery) + } + }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } } @@ -227,111 +237,42 @@ private[master] class Master( System.exit(0) } - case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) => - { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - sender, workerUiPort, publicAddress) + workerRef, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) - } - } - } - - case RequestSubmitDriver(description) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only accept driver submissions in ALIVE state." - sender ! SubmitDriverResponse(false, None, msg) - } else { - logInfo("Driver submitted " + description.command.mainClass) - val driver = createDriver(description) - persistenceEngine.addDriver(driver) - waitingDrivers += driver - drivers.add(driver) - schedule() - - // TODO: It might be good to instead have the submission client poll the master to determine - // the current status of the driver. For now it's simply "fire and forget". - - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") - } - } - - case RequestKillDriver(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - s"Can only kill drivers in ALIVE state." - sender ! KillDriverResponse(driverId, success = false, msg) - } else { - logInfo("Asked to kill driver " + driverId) - val driver = drivers.find(_.id == driverId) - driver match { - case Some(d) => - if (waitingDrivers.contains(d)) { - waitingDrivers -= d - self ! 
DriverStateChanged(driverId, DriverState.KILLED, None) - } else { - // We just notify the worker to kill the driver here. The final bookkeeping occurs - // on the return path when the worker submits a state change back to the master - // to notify it that the driver was successfully killed. - d.worker.foreach { w => - w.actor ! KillDriver(driverId) - } - } - // TODO: It would be nice for this to be a synchronous response - val msg = s"Kill request for $driverId submitted" - logInfo(msg) - sender ! KillDriverResponse(driverId, success = true, msg) - case None => - val msg = s"Driver $driverId has already finished or does not exist" - logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) - } - } - } - - case RequestDriverStatus(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only request driver status in ALIVE state." - sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) - } else { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } - case RegisterApplication(description) => { + case RegisterApplication(description, driver) => { + // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response } else { logInfo("Registering app " + description.name) - val app = createApplication(description, sender) + val app = createApplication(description, driver) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + driver.send(RegisteredApplication(app.id, self)) schedule() } } @@ -343,7 +284,7 @@ private[master] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -384,7 +325,7 @@ private[master] class Master( } } - case Heartbeat(workerId) => { + case Heartbeat(workerId, worker) => { idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -392,7 +333,7 @@ private[master] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + worker.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." 
+ " This worker was never registered, so ignoring the heartbeat.") @@ -444,30 +385,103 @@ private[master] class Master( logInfo(s"Received unregister request from application $applicationId") idToApp.get(applicationId).foreach(finishApplication) - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() } + } - case RequestMasterState => { - sender ! MasterStateResponse( - host, port, restServerBoundPort, - workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." + context.reply(SubmitDriverResponse(self, false, None, msg)) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + context.reply(SubmitDriverResponse(self, true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) + } } - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. + d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + context.reply(KillDriverResponse(self, driverId, success = true, msg)) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } + } + } + + case RequestDriverStatus(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." 
+ context.reply( + DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + context.reply(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) + case None => + context.reply(DriverStatusResponse(found = false, None, None, None, None)) + } + } + } + + case RequestMasterState => { + context.reply(MasterStateResponse( + address.host, address.port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } case BoundPortsRequest => { - sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) + context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) } } + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + private def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 @@ -479,7 +493,7 @@ private[master] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -496,7 +510,7 @@ private[master] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.endpoint.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -505,10 +519,8 @@ private[master] class Master( private def completeRecovery() { // Ensure "only-once" recovery semantics using a short synchronization period. - synchronized { - if (state != RecoveryState.RECOVERING) { return } - state = RecoveryState.COMPLETING_RECOVERY - } + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) @@ -623,10 +635,10 @@ private[master] class Master( private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! 
ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.endpoint.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { @@ -638,7 +650,7 @@ private[master] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -661,11 +673,11 @@ private[master] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.endpoint.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -687,14 +699,15 @@ private[master] class Master( schedule() } - private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): + ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -703,7 +716,7 @@ private[master] class Master( applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(app.driver) = app + endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app } @@ -717,8 +730,8 @@ private[master] class Master( logInfo("Removing app " + app.id) apps -= app idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address + endpointToApp -= app.driver + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -735,19 +748,19 @@ private[master] class Master( for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) + exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id)) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! 
ApplicationFinished(app.id) + w.endpoint.send(ApplicationFinished(app.id)) } } } @@ -768,7 +781,7 @@ private[master] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, None, app.desc.eventLogCodec) + eventLogDir, app.id, app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) @@ -832,14 +845,14 @@ private[master] class Master( private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray for (worker <- toRemove) { if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) + worker.id, WORKER_TIMEOUT_MS / 1000)) removeWorker(worker) } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } @@ -862,7 +875,7 @@ private[master] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.endpoint.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -891,57 +904,33 @@ private[master] class Master( } private[deploy] object Master extends Logging { - val systemName = "sparkMaster" - private val actorName = "Master" + val SYSTEM_NAME = "sparkMaster" + val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() - } - - /** - * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaUrl(sparkUrl: String, protocol: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - AkkaUtils.address(protocol, systemName, host, port, actorName) - } - - /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. 
- * - * @throws SparkException if the url is invalid - */ - def toAkkaAddress(sparkUrl: String, protocol: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address(protocol, systemName, host, port) + val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** - * Start the Master and return a four tuple of: - * (1) The Master actor system - * (2) The bound port - * (3) The web UI bound port - * (4) The REST server bound port, if any + * Start the Master and return a three tuple of: + * (1) The Master RpcEnv + * (2) The web UI bound port + * (3) The REST server bound port, if any */ - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { + conf: SparkConf): (RpcEnv, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf( - Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = RpcUtils.askTimeout(conf) - val portsRequest = actor.ask(BoundPortsRequest)(timeout) - val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] - (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) + val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) + val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, + new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) + val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 435b9b12f83b8..44cefbc77f08e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
*/ private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Master [options]\n" + "\n" + @@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8080)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 15c6296888f70..68c937188b333 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,7 +28,7 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to Master + // Master to itself case object CheckForWorkerTimeOut diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 9b3d48c6edc84..f751966605206 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val endpoint: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { @@ -107,4 +105,6 @@ private[spark] class WorkerInfo( def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 52758d6a7c4be..6fdff86f66e01 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,10 +17,7 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} import org.apache.spark.deploy.SparkCuratorUtil diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 06e265f99e231..e28e7e379ac91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask - import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorDesc @@ -32,14 
+29,12 @@ import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 6a7c74020bace..c3e20ebf8d6eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - Await.result(stateFuture, timeout) + master.askWithRetry[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { @@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } def handleDriverKillRequest(request: HttpServletRequest): Unit = { - handleKillRequest(request, id => { master ! RequestKillDriver(id) }) + handleKillRequest(request, id => { + master.ask[KillDriverResponse](RequestKillDriver(id)) + }) } private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 2111a8581f2e4..6174fc11f83d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -23,7 +23,6 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. 
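The UI pages above (ApplicationPage, MasterPage) and the MasterWebUI change just below all drop the akka.pattern.ask plus Await.result pair in favour of one typed askWithRetry call on the master's RpcEndpointRef, so the pages no longer carry their own timeout. A minimal sketch of that query shape, not taken from the patch itself, assuming the Spark 1.4-era internal API used throughout this diff; the object name MasterStateFetcher and its package placement are invented for the example:

package org.apache.spark.deploy.master.ui

import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.rpc.RpcEndpointRef

object MasterStateFetcher {
  // One blocking, typed round trip; the timeout/retry policy now lives inside the
  // RPC layer instead of being threaded through akka.pattern.ask and Await.result.
  def fetch(master: RpcEndpointRef): MasterStateResponse =
    master.askWithRetry[MasterStateResponse](RequestMasterState)
}
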
@@ -33,8 +32,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { - val masterActorRef = master.self - val timeout = RpcUtils.askTimeout(master.conf) + val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) val masterPage = new MasterPage(this) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 894cb78d8591a..5accaf78d0a51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -54,7 +54,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { + // scalastyle:off println System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + // scalastyle:on println System.exit(1) } masterUrl = value.stripPrefix("mesos://") @@ -73,7 +75,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case Nil => { if (masterUrl == null) { + // scalastyle:off println System.err.println("--master is required") + // scalastyle:on println printUsageAndExit(1) } } @@ -83,6 +87,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: } private def printUsageAndExit(exitCode: Int): Unit = { + // scalastyle:off println System.err.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + @@ -96,6 +101,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 502b9bb701ccf..d5b9bcab1423f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import akka.actor.ActorRef import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** @@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to * @param masterConf the conf used by the Master - * @param masterActor reference to the Master actor to which requests can be sent + * @param masterEndpoint reference to the Master endpoint to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: 
RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { protected override val submitRequestServlet = - new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = - new StandaloneKillRequestServlet(masterActor, masterConf) + new StandaloneKillRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = - new StandaloneStatusRequestServlet(masterActor, masterConf) + new StandaloneStatusRequestServlet(masterEndpoint, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( - DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion k.message = response.message @@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( - DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion @@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. 
*/ private[rest] class StandaloneSubmitRequestServlet( - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String, conf: SparkConf) extends SubmitRequestServlet { @@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion submitResponse.message = response.message diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 8198296eeb341..868cc35d06ef3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -59,7 +59,7 @@ private[mesos] class MesosSubmitRequestServlet( extends SubmitRequestServlet { private val DEFAULT_SUPERVISE = false - private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // mb private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 1386055eb8c48..ec51c3d935d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -21,7 +21,6 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{Utils, Clock, SystemClock} /** @@ -43,7 +43,7 @@ private[deploy] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String, val securityManager: SecurityManager) extends Logging { @@ -107,7 +107,7 @@ private[deploy] class DriverRunner( finalState = Some(state) - worker ! 
DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index d1a12b01e78f7..2d6be3042c905 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -60,7 +60,9 @@ object DriverWrapper { rpcEnv.shutdown() case _ => + // scalastyle:off println System.err.println("Usage: DriverWrapper [options]") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index fff17e1095042..29a5042285578 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -21,10 +21,10 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged @@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val webUiPort: Int, @@ -91,7 +91,7 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -159,7 +159,7 @@ private[deploy] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! 
ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ebc6cd76c6afd..82e9578bbcba5 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,15 +21,14 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.ExecutionContext import scala.util.Random - -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} @@ -38,32 +37,39 @@ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} -/** - * @param masterAkkaUrls Each url should be a valid akka url. - */ private[worker] class Worker( - host: String, - port: Int, + override val rpcEnv: RpcEnv, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], - actorSystemName: String, - actorName: String, + masterRpcAddresses: Array[RpcAddress], + systemName: String, + endpointName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends ThreadSafeRpcEndpoint with Logging { + + private val host = rpcEnv.address.host + private val port = rpcEnv.address.port Utils.checkHost(host, "Expected hostname") assert (port > 0) + // A scheduled executor used to send messages at the specified time. + private val forwordMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + + // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` + // methods. 
+ private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) + // For worker and executor IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -79,32 +85,26 @@ private[worker] class Worker( val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * - REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 - * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)) + private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)) private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") - private var master: ActorSelection = null - private var masterAddress: Address = null + private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val akkaUrl = AkkaUtils.address( - AkkaUtils.protocol(context.system), - actorSystemName, - host, - port, - actorName) - @volatile private var registered = false - @volatile private var connected = false + private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private var registered = false + private var connected = false private val workerId = generateWorkerId() private val sparkHome = if (testing) { @@ -136,7 +136,18 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - private var registrationRetryTimer: Option[Cancellable] = None + private var registerMasterFutures: Array[JFuture[_]] = null + private var registrationRetryTimer: Option[JScheduledFuture[_]] = None + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. 
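  // A hedged aside, not part of the patch: with corePoolSize = 0 and a SynchronousQueue
  // there is no task queue at all; each submitted registration either hands off to an
  // already-idle worker thread or starts a new one, up to one thread per master, and
  // idle threads are reclaimed after the 60-second keep-alive configured below.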
+ private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) var coresUsed = 0 var memoryUsed = 0 @@ -162,14 +173,13 @@ private[worker] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -181,24 +191,32 @@ private[worker] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(url: String, uiUrl: String) { + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. - activeMasterUrl = url + activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) + master = Some(masterRef) connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } - private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + masterRpcAddresses.map { masterAddress => + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + }) } } @@ -211,8 +229,7 @@ private[worker] class Worker( Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") /** @@ -235,21 +252,48 @@ private[worker] class Worker( * still not safe if the old master recovers within this interval, but this is a much * less likely scenario. */ - if (master != null) { - master ! 
RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) - } else { - // We are retrying the initial registration - tryRegisterAllMasters() + master match { + case Some(masterRef) => + // registered == false && master != None means we lost the connection to master, so + // masterRef cannot be used and we need to recreate it again. Note: we must not set + // master to None due to the above comments. + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + val masterAddress = masterRef.address + registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + })) + case None => + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + // We are retrying the initial registration + registerMasterFutures = tryRegisterAllMasters() } // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = Some( + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) } } else { logError("All masters are unresponsive! Giving up.") @@ -258,41 +302,67 @@ private[worker] class Worker( } } + /** + * Cancel last registeration retry, or do nothing if no retry + */ + private def cancelLastRegistrationRetry(): Unit = { + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures = null + } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = None + } + private def registerWithMaster() { - // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // onDisconnected may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. 
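The retry intervals computed above are fuzzed by a per-worker random multiplier so that a fleet of workers does not retry registration in lockstep. A back-of-the-envelope sketch of that arithmetic, assuming a fuzz lower bound of 0.5 (the constant itself is defined outside this hunk):

```scala
import scala.util.Random

object RegistrationFuzzSketch {
  def main(args: Array[String]): Unit = {
    // Assumed lower bound of the fuzz multiplier; the real constant is declared
    // elsewhere in Worker.scala.
    val fuzzLowerBound = 0.5
    val fuzz = new Random().nextDouble() + fuzzLowerBound
    val initialRetrySeconds = math.round(10 * fuzz)
    val prolongedRetrySeconds = math.round(60 * fuzz)
    println(s"initial retries every ${initialRetrySeconds}s, " +
      s"prolonged retries every ${prolongedRetrySeconds}s")
  }
}
```

Under that assumption the initial interval lands roughly between 5 and 15 seconds and the prolonged interval between 30 and 90 seconds.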
registrationRetryTimer match { case None => registered = false - tryRegisterAllMasters() + registerMasterFutures = tryRegisterAllMasters() connectionAttemptCount = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + " attempt scheduled already.") } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterUrl, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterUrl) + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) registered = true - changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + changeMaster(masterRef, masterWebUiUrl) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(SendHeartbeat) + } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! Heartbeat(workerId) } + if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Copy ids so that it can be used in the cleanup thread. 
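The `WorkDirCleanup` handler above snapshots the running application ids before handing the scan to the cleanup thread, so the RPC thread never blocks on the filesystem. A simplified sketch of the directory filter, using a plain last-modified check in place of Spark's recursive new-file scan and an illustrative app id:

```scala
import java.io.File

object WorkDirCleanupSketch {
  // Simplified version of the cleanup filter: skip directories belonging to running
  // applications and only flag those older than the retention TTL. The real code
  // recursively checks for any new files instead of this plain mtime test.
  def dirsToRemove(workDir: File, runningAppIds: Set[String], ttlMs: Long): Seq[File] = {
    val now = System.currentTimeMillis()
    Option(workDir.listFiles()).getOrElse(Array.empty[File]).toSeq
      .filter(_.isDirectory)
      .filterNot(dir => runningAppIds.contains(dir.getName))
      .filter(dir => now - dir.lastModified() > ttlMs)
  }

  def main(args: Array[String]): Unit = {
    val ttl = 7L * 24 * 3600 * 1000
    dirsToRemove(new File("/tmp/spark-work"), Set("app-20150701-0001"), ttl)
      .foreach(dir => println(s"would remove ${dir.getPath}"))
  }
}
```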
+ val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { val appDirs = workDir.listFiles() if (appDirs == null) { @@ -302,27 +372,27 @@ private[worker] class Worker( // the directory is used by an application - check that the application is not running // when cleaning up val appIdFromDir = dir.getName - val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + val isAppStillRunning = appIds.contains(appIdFromDir) dir.isDirectory && !isAppStillRunning && - !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => logInfo(s"Removing directory: ${dir.getPath}") Utils.deleteRecursively(dir) } - } + }(cleanupThreadExecutor) - cleanupFuture onFailure { + cleanupFuture.onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) - } + }(cleanupThreadExecutor) - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl, masterWebUiUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + changeMaster(masterRef, masterWebUiUrl) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case RegisterWorkerFailed(message) => if (!registered) { @@ -369,14 +439,14 @@ private[worker] class Worker( publicAddress, sparkHome, executorDir, - akkaUrl, + workerUri, conf, appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -384,14 +454,14 @@ private[worker] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => + sendToMaster(executorStateChanged) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -434,7 +504,7 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl, + workerUri, securityMgr) drivers(driverId) = driver driver.start() @@ -453,7 +523,7 @@ private[worker] class Worker( } } - case DriverStateChanged(driverId, state, exception) => { + case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { state match { case DriverState.ERROR => logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") @@ -466,23 +536,13 @@ private[worker] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - master ! 
DriverStateChanged(driverId, state, exception) + sendToMaster(driverStageChanged) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem coresUsed -= driver.driverDesc.cores } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, drivers.values.toList, - finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) - case ReregisterWithMaster => reregisterWithMaster() @@ -491,6 +551,21 @@ private[worker] class Worker( maybeCleanupApplication(id) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestWorkerState => + context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (master.exists(_.address == remoteAddress)) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false @@ -510,13 +585,29 @@ private[worker] class Worker( } } + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => + logWarning( + s"Dropping $message because the connection to master has not yet been established") + } + } + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { + cleanupThreadExecutor.shutdownNow() metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + cancelLastRegistrationRetry() + forwordMessageScheduler.shutdownNow() + registerMasterThreadPool.shutdownNow() executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() @@ -530,12 +621,12 @@ private[deploy] object Worker extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, @@ -544,18 +635,17 @@ private[deploy] object Worker extends Logging { masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None, - conf: SparkConf = new SparkConf): (ActorSystem, Int) = { + conf: SparkConf = new SparkConf): RpcEnv = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = 
securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses, + systemName, actorName, workDir, conf, securityMgr)) + rpcEnv } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 9678631da9f6f..e89d076802215 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. */ def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Worker [options] \n" + "\n" + @@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8081)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } @@ -160,11 +162,13 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } catch { case e: Exception => { totalMb = 2*1024 + // scalastyle:off println System.out.println("Failed to get total physical memory. 
Using " + totalMb + " MB") + // scalastyle:on println } } // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) + math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) } def checkWorkerMemory(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 83fb991891a41..fae5640b9a213 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.worker import org.apache.spark.Logging -import org.apache.spark.deploy.DeployMessages.SendHeartbeat import org.apache.spark.rpc._ /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 9f9f27d71e1ae..fd905feb97e92 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - private val workerActor = parent.worker.self - private val timeout = parent.timeout + private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? 
RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index b3bb5f911dbd7..334a5b10142aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - private[ui] val timeout = RpcUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index f3a26f54a81fb..fcd76ec52742a 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -66,7 +66,10 @@ private[spark] class CoarseGrainedExecutorBackend( case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } - case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + case Failure(e) => { + logError(s"Cannot register with driver: $driverUrl", e) + System.exit(1) + } }(ThreadUtils.sameThread) } @@ -232,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { argv = tail case Nil => case tail => + // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + // scalastyle:on println printUsageAndExit() } } @@ -246,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } private def printUsageAndExit() = { + // scalastyle:off println System.err.println( """ |"Usage: CoarseGrainedExecutorBackend [options] @@ -259,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { | --worker-url | --user-class-path |""".stripMargin) + // scalastyle:on println System.exit(1) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8f916e0502ecb..f7ef92bc80f91 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -443,7 +443,7 @@ private[spark] class Executor( try { val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") + logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } } catch { diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index c219d21fbefa9..532850dd57716 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} 
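Several hunks above replace the Akka `ask` + `Await.result` pattern with `askWithRetry`, a blocking call that retries a bounded number of times before giving up. A rough standalone sketch of those semantics (not Spark's `RpcEndpointRef` implementation), with hypothetical retry parameters:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

object AskWithRetrySketch {
  // Illustrative blocking ask-with-retry: try up to maxRetries times, sleeping
  // retryWait between failed attempts. This mirrors the semantics, not the
  // implementation, of askWithRetry.
  def askWithRetry[T](maxRetries: Int, retryWait: FiniteDuration, timeout: FiniteDuration)
                     (ask: => Future[T]): T = {
    var attempts = 0
    var lastError: Throwable = null
    while (attempts < maxRetries) {
      attempts += 1
      try {
        return Await.result(ask, timeout)
      } catch {
        case e: Exception =>
          lastError = e
          if (attempts < maxRetries) Thread.sleep(retryWait.toMillis)
      }
    }
    throw new RuntimeException(s"ask failed after $attempts attempts", lastError)
  }

  def main(args: Array[String]): Unit = {
    val state = askWithRetry(maxRetries = 3, retryWait = 100.millis, timeout = 1.second) {
      Future.successful("WorkerStateResponse")
    }
    println(state)
  }
}
```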
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil /** @@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat { } private[spark] class FixedLengthBinaryInputFormat - extends FileInputFormat[LongWritable, BytesWritable] { + extends FileInputFormat[LongWritable, BytesWritable] + with Logging { private var recordLength = -1 @@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) } if (recordLength <= 0) { - println("record length is less than 0, file cannot be split") + logDebug("record length is less than 0, file cannot be split") false } else { true diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index 67a376102994c..79cb0640c8672 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -57,16 +57,6 @@ private[nio] class BlockMessage() { } def set(buffer: ByteBuffer) { - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ typ = buffer.getInt() val idLength = buffer.getInt() val idBuilder = new StringBuilder(idLength) @@ -138,18 +128,6 @@ private[nio] class BlockMessage() { buffers += data } - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 7d0806f0c2580..f1c9ea8b64ca3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -43,16 +43,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) val newBlockMessages = new ArrayBuffer[BlockMessage]() val buffer = bufferMessage.buffers(0) buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ while (buffer.remaining() > 0) { val size = buffer.getInt() logDebug("Creating block message of size " + size + " bytes") @@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } } -private[nio] object BlockMessageArray { +private[nio] object BlockMessageArray extends Logging { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() @@ -123,10 +101,10 @@ private[nio] object BlockMessageArray { } } val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") + logDebug("Block message array created") val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") + logDebug("Converted to buffer 
message") val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) @@ -138,10 +116,11 @@ private[nio] object BlockMessageArray { }) newBuffer.flip val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) + logDebug("Copied to new buffer message, size = " + newBufferMessage.size) val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") + logDebug("Converted back to block message array") + // scalastyle:off println newBlockMessageArray.foreach(blockMessage => { blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { @@ -154,6 +133,7 @@ private[nio] object BlockMessageArray { } } }) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index c0bca2c4bc994..9143918790381 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -1016,7 +1016,9 @@ private[spark] object ConnectionManager { val conf = new SparkConf val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // scalastyle:off println println("Received [" + msg + "] from [" + id + "]") + // scalastyle:on println None }) @@ -1033,6 +1035,7 @@ private[spark] object ConnectionManager { System.gc() } + // scalastyle:off println def testSequentialSending(manager: ConnectionManager) { println("--------------------------") println("Sequential Sending") @@ -1150,4 +1153,5 @@ private[spark] object ConnectionManager { println() } } + // scalastyle:on println } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 33e6998b2cb10..e17bd47905d7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.{SerializableConfiguration, Utils} -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). 
@@ -37,9 +37,11 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) + private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + + override def getCheckpointFile: Option[String] = Some(checkpointPath) override def getPartitions: Array[Partition] = { val cpath = new Path(checkpointPath) @@ -59,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - override def getPreferredLocations(split: Partition): Seq[String] = { val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) @@ -74,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) CheckpointRDD.readFromFile(file, broadcastedConf, context) } - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } + // CheckpointRDD should not be checkpointed again + override def checkpoint(): Unit = { } + override def doCheckpoint(): Unit = { } } private[spark] object CheckpointRDD extends Logging { diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index dc60d48927624..defdabf95ac4b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread("stderr reader for " + command) { override def run() { for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + // scalastyle:off println System.err.println(line) + // scalastyle:on println } } }.start() @@ -133,6 +135,7 @@ private[spark] class PipedRDD[T: ClassTag]( override def run() { val out = new PrintWriter(proc.getOutputStream) + // scalastyle:off println // input the pipe context firstly if (printPipeContext != null) { printPipeContext(out.println(_)) @@ -144,6 +147,7 @@ private[spark] class PipedRDD[T: ClassTag]( out.println(elem) } } + // scalastyle:on println out.close() } }.start() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 10610f4b6f1ff..9f7ebae3e9af3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag]( @transient private var partitions_ : Array[Partition] = null /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) /** * Get the list of dependencies of this RDD, taking into account whether the @@ -890,6 +890,10 @@ abstract class RDD[T: ClassTag]( * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. 
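The note being added to `toLocalIterator` (it continues just below) warns that the iterator runs one Spark job per partition, so a shuffled input should be cached before iterating. A brief usage sketch, assuming a local `SparkContext` is acceptable for the example:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ToLocalIteratorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    // toLocalIterator launches one job per partition, so cache the shuffled RDD
    // first to avoid recomputing the shuffle for every partition it pulls.
    val grouped = sc.parallelize(1 to 100, 4).map(i => (i % 10, i)).reduceByKey(_ + _).cache()
    grouped.toLocalIterator.foreach(println)
    sc.stop()
  }
}
```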
+ * + * Note: this results in multiple Spark jobs, and if the input RDD is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input RDD should be cached first. */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { @@ -1447,12 +1451,16 @@ abstract class RDD[T: ClassTag]( * executed on this RDD. It is strongly recommended that this RDD is persisted in * memory, otherwise saving it on a file will require recomputation. */ - def checkpoint() { + def checkpoint(): Unit = { if (context.checkpointDir.isEmpty) { throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() + // NOTE: we use a global lock here due to complexities downstream with ensuring + // children RDD partitions point to the correct parent partitions. In the future + // we should revisit this consideration. + RDDCheckpointData.synchronized { + checkpointData = Some(new RDDCheckpointData(this)) + } } } @@ -1493,7 +1501,7 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassTag] = { + protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index acbd31aacdf59..4f954363bed8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -22,16 +22,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.spark._ -import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} import org.apache.spark.util.SerializableConfiguration /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + * [ Initialized --> checkpointing in progress --> checkpointed ]. */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value + val Initialized, CheckpointingInProgress, Checkpointed = Value } /** @@ -46,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) import CheckpointState._ // The checkpoint state of the associated RDD. - var cpState = Initialized + private var cpState = Initialized // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None + private var cpFile: Option[String] = None // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None + // This is defined if and only if `cpState` is `Checkpointed`. + private var cpRDD: Option[CheckpointRDD[T]] = None - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } + // TODO: are we sure we need to use a global lock in the following methods? 
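The reworked `RDDCheckpointData` above collapses the lifecycle to three states and flips `Initialized` to `CheckpointingInProgress` under a lock so that only one caller materializes the checkpoint. A condensed sketch of that guard, with object-local state standing in for the real class:

```scala
object CheckpointStateSketch {
  // Condensed model of the three-state checkpoint lifecycle:
  // Initialized -> CheckpointingInProgress -> Checkpointed.
  object CheckpointState extends Enumeration {
    val Initialized, CheckpointingInProgress, Checkpointed = Value
  }
  import CheckpointState._

  private var state = Initialized

  // Atomically flip Initialized -> CheckpointingInProgress so that only one caller
  // materializes the checkpoint; any later caller returns without doing work.
  def doCheckpoint(write: () => Unit): Unit = {
    val shouldWrite = this.synchronized {
      if (state == Initialized) { state = CheckpointingInProgress; true } else false
    }
    if (shouldWrite) {
      write() // the expensive job that writes the RDD's partitions to the DFS
      this.synchronized { state = Checkpointed }
    }
  }

  def main(args: Array[String]): Unit = {
    doCheckpoint(() => println("writing checkpoint"))
    doCheckpoint(() => println("never printed: already in progress or done"))
    println(state)
  }
}
```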
// Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } + def isCheckpointed: Boolean = RDDCheckpointData.synchronized { + cpState == Checkpointed } // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } + def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized { + cpFile } - // Do the checkpointing of the RDD. Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return + /** + * Materialize this RDD and write its content to a reliable DFS. + * This is called immediately after the first action invoked on this RDD has completed. + */ + def doCheckpoint(): Unit = { + + // Guard against multiple threads checkpointing the same RDD by + // atomically flipping the state of this RDDCheckpointData RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { + if (cpState == Initialized) { cpState = CheckpointingInProgress } else { return @@ -87,7 +86,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) + throw new SparkException(s"Failed to create checkpoint path $path") } // Save to file, and reload it as an RDD @@ -99,6 +98,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) } } + + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( @@ -113,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed } - logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) - } - - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } + logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}") } - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } + def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { + cpRDD.get.partitions } - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } + def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { + cpRDD } } private[spark] object RDDCheckpointData { + + /** Return the path of the directory to which this RDD's checkpoint data is written. */ def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { - sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } + sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } } + /** Clean up the files associated with the checkpoint data for this RDD. 
*/ def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { rddCheckpointDataPath(sc, rddId).foreach { path => val fs = path.getFileSystem(sc.hadoopConfiguration) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 69181edb9ad44..6ae47894598be 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -17,8 +17,7 @@ package org.apache.spark.rpc -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.Future import scala.reflect.ClassTag import org.apache.spark.util.RpcUtils @@ -32,7 +31,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) private[this] val maxRetries = RpcUtils.numRetries(conf) private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) /** * return the address for the [[RpcEndpointRef]] @@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. */ - def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to @@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts var attempts = 0 var lastException: Exception = null @@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) attempts += 1 try { val future = ask[T](message, timeout) - val result = Await.result(future, timeout) + val result = timeout.awaitResult(future) if (result == null) { throw new SparkException("Actor returned null") } @@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) lastException = e logWarning(s"Error sending message [message = $message] in $attempts attempts", e) } - Thread.sleep(retryWaitMs) + + if (attempts < maxRetries) { + Thread.sleep(retryWaitMs) + } } throw new SparkException( s"Error sending message [message = $message]", lastException) } + } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 12b6b28d4d7ec..1709bdf560b6f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -18,8 +18,10 @@ package org.apache.spark.rpc import java.net.URI +import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Future} +import scala.concurrent.{Awaitable, Await, Future} +import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.{SecurityManager, SparkConf} @@ -66,7 +68,7 @@ private[spark] object RpcEnv { */ private[spark] abstract class RpcEnv(conf: SparkConf) { - private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) + private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) /** * Return 
RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -94,7 +96,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. */ def setupEndpointRefByURI(uri: String): RpcEndpointRef = { - Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } /** @@ -158,6 +160,8 @@ private[spark] case class RpcAddress(host: String, port: Int) { val hostPort: String = host + ":" + port override val toString: String = hostPort + + def toSparkURL: String = "spark://" + hostPort } @@ -182,3 +186,107 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. 
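The `RpcTimeout.awaitResult` helper above exists so that a bare `TimeoutException` is rethrown with the configuration property that controls the timeout in its message. A condensed standalone version of that idea (the class below is illustrative, not Spark's `RpcTimeout`):

```scala
import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

object RpcTimeoutSketch {
  // Illustrative stand-in: wait on a future and, on timeout, rethrow with the
  // controlling configuration property appended to the message.
  class ConfiguredTimeout(val duration: FiniteDuration, val timeoutProp: String) {
    def awaitResult[T](f: Future[T]): T =
      try {
        Await.result(f, duration)
      } catch {
        case te: TimeoutException =>
          throw new TimeoutException(
            s"${te.getMessage}. This timeout is controlled by $timeoutProp")
      }
  }

  def main(args: Array[String]): Unit = {
    val timeout = new ConfiguredTimeout(100.millis, "spark.rpc.askTimeout")
    val neverCompletes = Promise[String]().future
    try {
      timeout.awaitResult(neverCompletes)
    } catch {
      case e: TimeoutException => println(e.getMessage)
    }
  }
}
```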
+ * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 0161962cde073..f2d87f68341af 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future -import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -180,10 +179,10 @@ private[spark] class AkkaRpcEnv private[akka] ( }) } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - _sender ! AkkaFailure(e) - } else { + _sender ! AkkaFailure(e) + if (!needReply) { + // If the sender does not require a reply, it may not handle the exception. So we rethrow + // "e" to make sure it will be processed. throw e } } @@ -214,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] ( override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). + // this is just in case there is a timeout from creating the future in resolveOne, we want the + // exception to indicate the conf that determines the timeout + recover(defaultLookupTimeout.addMessageIfTimeout) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -295,8 +297,8 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! 
AkkaMessage(message, false) } - override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { // The function will run in the calling thread, so it should be short and never block. case msg @ AkkaMessage(message, reply) => if (reply) { @@ -307,7 +309,8 @@ private[akka] class AkkaRpcEndpointRef( } case AkkaFailure(e) => Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T] + }(ThreadUtils.sameThread).mapTo[T]. + recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } override def toString: String = s"${getClass.getSimpleName}($actorRef)" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c6029675eab0e..11b12edf7eaf1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ @@ -188,7 +189,7 @@ class DAGScheduler( blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( - BlockManagerHeartbeat(blockManagerId), 600 seconds) + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } // Called by TaskScheduler when an executor fails. @@ -870,7 +871,7 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + stage.makeNewStageAttempt(partitionsToCompute.size) outputCommitCoordinator.stageStart(stage.id) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) @@ -935,8 +936,8 @@ class DAGScheduler( logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) - taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties)) + taskScheduler.submitTasks(new TaskSet( + tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 529a5b2bf1a0d..62b05033a9281 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -140,7 +140,9 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. 
*/ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) + // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) + // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index e55b76c36cc5f..f96eb8ca0ae00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener val date = new Date(System.currentTimeMillis()) writeInfo = dateFormat.get.format(date) + ": " + info } + // scalastyle:off println jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) + // scalastyle:on println } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 14ab2b86e1b77..b86724de2cb73 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -62,28 +62,28 @@ private[spark] abstract class Stage( var pendingTasks = new HashSet[Task[_]] + /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm - /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ - var latestInfo: StageInfo = StageInfo.fromStage(this) + /** + * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized + * here, before any attempts have actually been created, because the DAGScheduler uses this + * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts + * have been created). + */ + private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) - /** Return a new attempt id, starting with 0. */ - def newAttemptId(): Int = { - val id = nextAttemptId + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ + def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = { + _latestInfo = StageInfo.fromStage(this, nextAttemptId, Some(numPartitionsToCompute)) nextAttemptId += 1 - id } - /** - * The id for the **next** stage attempt. - * - * The unusual meaning of this method means its unlikely to hold the value you are interested in - * -- you probably want to use [[latestInfo.attemptId]] - */ - private[spark] def attemptId: Int = nextAttemptId + /** Returns the StageInfo for the most recent attempt for this stage. */ + def latestInfo: StageInfo = _latestInfo override final def hashCode(): Int = id override final def equals(other: Any): Boolean = other match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index e439d2a7e1229..5d2abbc67e9d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -70,12 +70,12 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. 
*/ - def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = { + def fromStage(stage: Stage, attemptId: Int, numTasks: Option[Int] = None): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos new StageInfo( stage.id, - stage.attemptId, + attemptId, stage.name, numTasks.getOrElse(stage.numTasks), rddInfos, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ccf1dc5af6120..687ae9620460f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend( val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() waitForRegistration() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 190ff61d689d1..bc67abb5df446 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend( private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. 
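The `Stage`/`StageInfo` changes above make `latestInfo` the single source of truth for the current attempt, with `makeNewStageAttempt` minting a fresh `StageInfo` and bumping the attempt id. A toy model of that bookkeeping (the case class below is illustrative, not Spark's `StageInfo`):

```scala
object StageAttemptSketch {
  // Toy model of the attempt bookkeeping: latestInfo always points at the StageInfo
  // of the most recent attempt, and attempt ids are assigned when attempts are created.
  case class StageInfo(stageId: Int, attemptId: Int, numTasks: Int)

  class Stage(val id: Int, val numTasks: Int) {
    private var nextAttemptId = 0
    private var _latestInfo = StageInfo(id, nextAttemptId, numTasks)

    def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = {
      _latestInfo = StageInfo(id, nextAttemptId, numPartitionsToCompute)
      nextAttemptId += 1
    }

    def latestInfo: StageInfo = _latestInfo
  }

  def main(args: Array[String]): Unit = {
    val stage = new Stage(id = 3, numTasks = 8)
    stage.makeNewStageAttempt(8)
    println(stage.latestInfo) // attemptId 0 for the first attempt that is actually submitted
    stage.makeNewStageAttempt(2)
    println(stage.latestInfo) // attemptId 1 for the retry, possibly with fewer tasks
  }
}
```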
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 6b8edca5aa485..cbade131494bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,18 +18,21 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{Collections, List => JList} +import java.util.{List => JList, Collections} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.collect.HashBiMap import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -60,12 +63,34 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocatorManager calls [[doRequesTotalExecutors]]. + */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + + private val pendingRemovedSlaveIds = new HashSet[String] + // private lock object protecting mutable state above. 
Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -82,7 +107,7 @@ private[spark] class CoarseMesosSchedulerBackend( startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -116,10 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) @@ -129,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend( command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + - s" --driver-url $driverUrl" + + s" --driver-url $driverURL" + s" --executor-id ${offer.getSlaveId.getValue}" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + @@ -138,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + - s" --driver-url $driverUrl" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -151,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend( command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { @@ -168,15 +201,19 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. 
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { + stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { - val slaveId = offer.getSlaveId.toString + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -190,42 +227,36 @@ private[spark] class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) + .addResources(createResource("mem", calculateTotalMemory(sc))) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder) } + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) + Collections.singleton(offer.getId), + Collections.singleton(task.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo("Mesos task " + taskId + " is now " + state) - synchronized { + stateLock.synchronized { if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId @@ -243,8 +274,9 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } + executorTerminated(d, slaveId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node - mesosDriver.reviveOffers() + d.reviveOffers() } } } @@ -263,18 +295,39 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, 
slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. + */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.contains(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo("Mesos slave lost: " + slaveId.getValue) + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -285,4 +338,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. + logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.contains(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. 
+ // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1067a7f1caf4c..d3a20f822176e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 49de85ef48ada..d72e2af456e15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,14 +23,14 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
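// A simplified, self-contained sketch of the executor-limit bookkeeping added to
// CoarseMesosSchedulerBackend above: dynamic allocation caps the number of executors through
// doRequestTotalExecutors, and new offers are only accepted while the number of live Mesos
// tasks is below that cap. DemoCoarseBackend is an illustrative stand-in, not a Spark API.
import scala.collection.mutable

class DemoCoarseBackend {
  private val taskIdToSlaveId = mutable.HashMap.empty[Int, String]
  private var executorLimitOption: Option[Int] = None

  // Int.MaxValue until the allocation manager has communicated a limit.
  def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue)

  def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
    executorLimitOption = Some(requestedTotal)
    true
  }

  // Mirrors the offer-acceptance condition: launch only while below the executor limit.
  def tryLaunch(taskId: Int, slaveId: String): Boolean = {
    if (taskIdToSlaveId.size < executorLimit) {
      taskIdToSlaveId(taskId) = slaveId
      true
    } else {
      false
    }
  }
}

object DemoCoarseBackendApp extends App {
  val backend = new DemoCoarseBackend
  backend.doRequestTotalExecutors(1)
  assert(backend.tryLaunch(0, "slave-a"))   // below the cap: accept the offer
  assert(!backend.tryLaunch(1, "slave-b"))  // cap reached: decline the offer
}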
Each Spark task is mapped to a @@ -59,6 +59,10 @@ private[spark] class MesosSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { @@ -71,8 +75,8 @@ private[spark] class MesosSchedulerBackend( val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -115,14 +119,14 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) .setScalar( Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) + .setValue(calculateTotalMemory(sc)).build()) .build() val executorInfo = MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -191,13 +195,31 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. 
CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt @@ -223,15 +245,15 @@ private[spark] class MesosSchedulerBackend( val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) acceptedOffers .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) - } + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } + } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? @@ -251,8 +273,6 @@ private[spark] class MesosSchedulerBackend( d.declineOffer(o.getId) } - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d11228f3d016a..925702e63afd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,14 +17,17 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util.List +import java.util.{List => JList} import java.util.concurrent.CountDownLatch import scala.collection.JavaConversions._ +import scala.util.control.NonFatal -import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} -import org.apache.mesos.{MesosSchedulerDriver, Scheduler} -import org.apache.spark.Logging +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.GeneratedMessage +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.util.Utils /** @@ -36,7 +39,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private final val registerLatch = new CountDownLatch(1) // Driver for talking to Mesos - protected var mesosDriver: MesosSchedulerDriver = null + protected var mesosDriver: SchedulerDriver = null /** * Starts the MesosSchedulerDriver with the provided information. 
This method returns @@ -86,10 +89,150 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Get the amount of resources for the specified type from the resource list */ - protected def getResource(res: List[Resource], name: String): Double = { + protected def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } 0.0 } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value. + * The attribute values are the Mesos attribute types (scalar, ranges, set or text). + * @param offerAttributes the attributes from a Mesos resource offer + * @return a Map from attribute name to its typed protobuf value + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + + /** + * Match the requirements (if any) to the offer attributes. + * If no attribute requirements are specified, return true. + * If an attribute is required but no values are given, a simple attribute presence check is performed. + * If an attribute name and values are specified, a subset match is performed on the slave attributes. + */ + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if any of the required values is less than or equal to the offered scalar value + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that falls within the offered range + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values are a subset of the offered set + requiredValues.subsetOf(offeredValue.getItemList.toSet) + case Some(textValue: Value.Text) => + // check if the offered text value equals one of the required values; if multiple + // values are specified, we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } + } + } + + /** + * Parses the attribute constraints provided to Spark and builds a matching data structure: + * Map[<attribute-name>, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'.
The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone" -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constraints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match against resource offers. + */ + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { + case (k, v) => + if (v == null || v.isEmpty) { + (k, Set[String]()) + } else { + (k, v.split(',').toSet) + } + } + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } + } + } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return total memory requirement: executorMemory plus the overhead, where the overhead + * defaults to (0.1 * executorMemory) or MEMORY_OVERHEAD_MINIMUM (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3078a1b10be8b..776e5d330e3c7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.local +import java.io.File +import java.net.URL import java.nio.ByteBuffer import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} @@ -40,6 +42,7 @@ private case class StopExecutor() */ private[spark] class LocalEndpoint( override val rpcEnv: RpcEnv, + userClassPath: Seq[URL], scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) @@ -51,7 +54,7 @@ private[spark] class LocalEndpoint( private val localExecutorHostname = "localhost" private val executor = new Executor( - localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) + localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => @@ -97,10 +100,22 @@ private[spark] class LocalBackend( private val appId = "local-" + System.currentTimeMillis var localEndpoint: RpcEndpointRef = null + private val userClassPath = getUserClasspath(conf) + + /** + * Returns a list of URLs representing the user classpath.
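// A self-contained sketch of the constraint-string grammar and executor-memory sizing documented
// in MesosSchedulerUtils above. parseConstraints and totalMemoryMb are simplified stand-ins
// (no Guava Splitter, no SparkContext) meant only to illustrate the expected behaviour, not the
// real implementation.
object MesosUtilsSketch extends App {
  def parseConstraints(constraints: String): Map[String, Set[String]] =
    if (constraints.isEmpty) {
      Map.empty
    } else {
      constraints.split(';').map { attr =>
        val parts = attr.split(":", 2)
        val values =
          if (parts.length < 2 || parts(1).isEmpty) Set.empty[String]
          else parts(1).split(',').toSet
        parts(0) -> values
      }.toMap
    }

  assert(parseConstraints("tachyon:true;zone:us-east-1a,us-east-1b") ==
    Map("tachyon" -> Set("true"), "zone" -> Set("us-east-1a", "us-east-1b")))
  assert(parseConstraints("") == Map.empty)

  // Memory sizing: overhead defaults to max(10% of executor memory, 384 MB) unless overridden.
  def totalMemoryMb(executorMemoryMb: Int, overheadOverrideMb: Option[Int] = None): Int =
    overheadOverrideMb.getOrElse(math.max(0.10 * executorMemoryMb, 384).toInt) + executorMemoryMb

  assert(totalMemoryMb(4096) == 4505)                                 // overhead = 409 MB (10%)
  assert(totalMemoryMb(1024) == 1408)                                 // 384 MB floor applies
  assert(totalMemoryMb(4096, overheadOverrideMb = Some(1024)) == 5120) // explicit override wins
}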
+ * + * @param conf Spark configuration. + */ + def getUserClasspath(conf: SparkConf): Seq[URL] = { + val userClassPathStr = conf.getOption("spark.executor.extraClassPath") + userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) + } override def start() { localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( - "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) + "LocalBackendEndpoint", + new LocalEndpoint(SparkEnv.get.rpcEnv, userClassPath, scheduler, this, totalCores)) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7cdae22b0e253..f70f701494dbf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -33,7 +33,7 @@ class BlockManagerMaster( isDriver: Boolean) extends Logging { - val timeout = RpcUtils.askTimeout(conf) + val timeout = RpcUtils.askRpcTimeout(conf) /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { @@ -106,7 +106,7 @@ class BlockManagerMaster( logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -118,7 +118,7 @@ class BlockManagerMaster( logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -132,7 +132,7 @@ class BlockManagerMaster( s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -176,8 +176,8 @@ class BlockManagerMaster( CanBuildFrom[Iterable[Future[Option[BlockStatus]]], Option[BlockStatus], Iterable[Option[BlockStatus]]]] - val blockStatus = Await.result( - Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread), timeout) + val blockStatus = timeout.awaitResult( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread)) if (blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } @@ -199,7 +199,7 @@ class BlockManagerMaster( askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) - Await.result(future, timeout) + timeout.awaitResult(future) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 91ef86389a0c3..5f537692a16c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. 
+ */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") + Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 06e616220c706..c8356467fab87 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -68,7 +68,9 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) + // scalastyle:on println } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") @@ -210,10 +212,16 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index ba03acdb38cc5..5a8c2914314c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -38,9 +38,11 @@ private[spark] object UIWorkloadGenerator { def main(args: Array[String]) { if (args.length < 3) { + // scalastyle:off println println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") + // scalastyle:on println System.exit(1) } @@ -96,6 +98,7 @@ private[spark] object UIWorkloadGenerator { for ((desc, job) <- jobs) { new Thread { override def run() { + // scalastyle:off println try { setProperties(desc) job() @@ -106,6 +109,7 @@ private[spark] object UIWorkloadGenerator { } finally { barrier.release() } + // scalastyle:on println } }.start Thread.sleep(INTER_JOB_WAIT_MS) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 39583af14390d..a88fc4c37d3c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.HashMap -import org.apache.spark.{ExceptionFailure, SparkContext} +import org.apache.spark.{Resubmitted, ExceptionFailure, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import 
org.apache.spark.storage.{StorageStatus, StorageStatusListener} @@ -92,15 +92,22 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId - executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 - executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration taskEnd.reason match { + case Resubmitted => + // Note: For resubmitted tasks, we continue to use the metrics that belong to the + // first attempt of this task. This may not be 100% accurate because the first attempt + // could have failed half-way through. The correct fix would be to keep track of the + // metrics added by each attempt, but this is much more complicated. + return case e: ExceptionFailure => executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 case _ => executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 } + executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 + executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration + // Update shuffle read/write val metrics = taskEnd.taskMetrics if (metrics != null) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index e96bf49d0dd14..ff0a339a39c65 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -332,7 +332,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - getGettingResultTime(info).toDouble + getGettingResultTime(info, currentTime).toDouble } val gettingResultQuantiles = @@ -346,7 +346,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - getSchedulerDelay(info, metrics.get).toDouble + getSchedulerDelay(info, metrics.get, currentTime).toDouble } val schedulerDelayTitle = Scheduler Delay @@ -544,7 +544,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskUIData.taskInfo) + val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime) val gettingResultTimeProportion = toProportion(gettingResultTime) val schedulerDelay = totalExecutionTime - (executorComputingTime + shuffleReadTime + shuffleWriteTime + @@ -570,6 +570,35 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val index = taskInfo.index val attempt = taskInfo.attempt + + val svgTag = + if (totalExecutionTime == 0) { + // SPARK-8705: Avoid invalid attribute error in JavaScript if execution time is 0 + """""" + } else { + s""" + | + | + | + | + | + | + |""".stripMargin + } val timelineObject = s""" |{ @@ -595,32 +624,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { |
Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)} |
Result Serialization Time: ${UIUtils.formatDuration(serializationTime)} |
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}"> - | - | - | - | - | - | - | - |', + |$svgTag', |'start': new Date($launchTime), |'end': new Date($finishTime) |} - |""".stripMargin.replaceAll("\n", " ") + |""".stripMargin.replaceAll("""[\r\n]+""", " ") timelineObject }.mkString("[", ",", "]") @@ -677,11 +685,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) + val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info) + val gettingResultTime = getGettingResultTime(info, currentTime) val maybeAccumulators = info.accumulables val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} @@ -844,32 +852,31 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {errorSummary}{details} } - private def getGettingResultTime(info: TaskInfo): Long = { - if (info.gettingResultTime > 0) { - if (info.finishTime > 0) { + private def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { + if (info.gettingResult) { + if (info.finished) { info.finishTime - info.gettingResultTime } else { // The task is still fetching the result. - System.currentTimeMillis - info.gettingResultTime + currentTime - info.gettingResultTime } } else { 0L } } - private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { - val totalExecutionTime = - if (info.gettingResult) { - info.gettingResultTime - info.launchTime - } else if (info.finished) { - info.finishTime - info.launchTime - } else { - 0 - } - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) - math.max( - 0, - totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = { + if (info.finished) { + val totalExecutionTime = info.finishTime - info.launchTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - + getGettingResultTime(info, currentTime)) + } else { + // The task is still running and the metrics like executorRunTime are not available. 
+ 0L + } } } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 96aa2fe164703..c179833e5b06a 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -18,8 +18,6 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap -import scala.concurrent.Await -import scala.concurrent.duration.FiniteDuration import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -28,6 +26,7 @@ import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} +import org.apache.spark.rpc.RpcTimeout /** * Various utility classes for working with Akka. @@ -147,7 +146,7 @@ private[spark] object AkkaUtils extends Logging { def askWithReply[T]( message: Any, actor: ActorRef, - timeout: FiniteDuration): T = { + timeout: RpcTimeout): T = { askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) } @@ -160,7 +159,7 @@ private[spark] object AkkaUtils extends Logging { actor: ActorRef, maxAttempts: Int, retryInterval: Long, - timeout: FiniteDuration): T = { + timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts if (actor == null) { throw new SparkException(s"Error sending message [message = $message]" + @@ -171,8 +170,8 @@ private[spark] object AkkaUtils extends Logging { while (attempts < maxAttempts) { attempts += 1 try { - val future = actor.ask(message)(timeout) - val result = Await.result(future, timeout) + val future = actor.ask(message)(timeout.duration) + val result = timeout.awaitResult(future) if (result == null) { throw new SparkException("Actor returned null") } @@ -198,9 +197,9 @@ private[spark] object AkkaUtils extends Logging { val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = RpcUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupRpcTimeout(conf) logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) } def makeExecutorRef( @@ -212,9 +211,9 @@ private[spark] object AkkaUtils extends Logging { val executorActorSystemName = SparkEnv.executorActorSystemName Utils.checkHost(host, "Expected hostname") val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = RpcUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupRpcTimeout(conf) logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) } def protocol(actorSystem: ActorSystem): String = { diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 1bab707235b89..950b69f7db641 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -52,9 +52,11 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va } def showQuantiles(out: PrintStream = System.out): Unit = { + // scalastyle:off println 
out.println("min\t25%\t50%\t75%\tmax") getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) @@ -64,8 +66,10 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va * @param out */ def summary(out: PrintStream = System.out) { + // scalastyle:off println out.println(statCounter) showQuantiles(out) + // scalastyle:on println } } @@ -80,8 +84,10 @@ private[spark] object Distribution { } def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") quantiles.foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index f16cc8e7e42c6..7578a3b1d85f2 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,11 +17,11 @@ package org.apache.spark.util -import scala.concurrent.duration._ +import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} object RpcUtils { @@ -47,14 +47,22 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC ask operations. */ + private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = { + RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") + } + + @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0") def askTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.askTimeout", - conf.get("spark.network.timeout", "120s")) seconds + askRpcTimeout(conf).duration } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ + private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { + RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") + } + + @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0") def lookupTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.lookupTimeout", - conf.get("spark.network.timeout", "120s")) seconds + lookupRpcTimeout(conf).duration } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 19157af5b6f4d..b6b932104a94d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -80,6 +80,12 @@ private[spark] object Utils extends Logging { */ val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + /** + * Define a default value for driver memory here since this value is referenced across the code + * base and nearly all files already use Utils.scala + */ + val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null @@ -727,7 +733,12 @@ private[spark] object Utils extends Logging { localRootDirs } - private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + /** + * Return the configured local directories where Spark can write files. 
This + * method does not create any directories on its own, it only encapsulates the + * logic of locating the local directories according to deployment mode. + */ + def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has @@ -743,27 +754,29 @@ private[spark] object Utils extends Logging { Option(conf.getenv("SPARK_LOCAL_DIRS")) .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) .split(",") - .flatMap { root => - try { - val rootDir = new File(root) - if (rootDir.exists || rootDir.mkdirs()) { - val dir = createTempDir(root) - chmod700(dir) - Some(dir.getAbsolutePath) - } else { - logError(s"Failed to create dir in $root. Ignoring this directory.") - None - } - } catch { - case e: IOException => - logError(s"Failed to create local root dir in $root. Ignoring this directory.") - None - } - } - .toArray } } + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + getConfiguredLocalDirs(conf).flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + }.toArray + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -2333,3 +2346,36 @@ private[spark] class RedirectThread( } } } + +/** + * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it + * in a circular buffer. The current contents of the buffer can be accessed using + * the toString method. 
+ */ +private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](sizeInBytes) + + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.length + } + + override def toString: String = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next() else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while (line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index c4a7b4441c85c..85fb923cd9bc7 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -70,12 +70,14 @@ private[spark] object XORShiftRandom { * @param args takes one argument - the number of random numbers to generate */ def main(args: Array[String]): Unit = { + // scalastyle:off println if (args.length != 1) { println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") println("Usage: XORShiftRandom number_of_random_numbers_to_generate") System.exit(1) } println(benchmark(args(0).toInt)) + // scalastyle:on println } /** diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java new file mode 100644 index 0000000000000..ea8755e21eb68 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.UUID; + +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeExternalSorterSuite { + + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + + File tempDir; + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + tempDir = new File(Utils.createTempDir$default$1()); + taskContext = mock(TaskContext.class); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock 
invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); + } + + private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { + final int[] arr = new int[] { value }; + sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + } + + @Test + public void testSortingOnlyByPrefix() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + insertNumber(sorter, 5); + insertNumber(sorter, 1); + insertNumber(sorter, 3); + sorter.spill(); + insertNumber(sorter, 4); + sorter.spill(); + insertNumber(sorter, 2); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(i, iter.getKeyPrefix()); + assertEquals(4, iter.getRecordLength()); + // TODO: read rest of value. + } + + // TODO: test for cleanup: + // assert(tempDir.isEmpty) + } + + @Test + public void testSortingEmptyArrays() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(0, iter.getKeyPrefix()); + assertEquals(0, iter.getRecordLength()); + } + } + +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java new file mode 100644 index 0000000000000..909500930539c --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Arrays; + +import org.junit.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { + final byte[] strBytes = new byte[length]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, length); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + mock(RecordComparator.class), + mock(PrefixComparator.class), + 100); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testSortingOnlyByIntegerPrefix() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + position += 4; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + prefixComparator, dataToSort.length); + // Given a page of records, insert those records into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). 
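+ // Record layout in the data page: a 4-byte length header followed by that many bytes of
+ // UTF-8 string data. The sorter never copies the record itself; insertRecord() below only
+ // stores the encoded (page, offset) pointer together with the partition-id key prefix.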
+ final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + final String str = getStringFromDataPage(baseObject, position + 4, recordLength); + final int partitionId = hashPartitioner.getPartition(str); + sorter.insertRecord(address, partitionId); + position += 4 + recordLength; + } + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + iter.loadNext(); + final String str = + getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength()); + final long keyPrefix = iter.getKeyPrefix(); + assertThat(str, isIn(Arrays.asList(dataToSort))); + assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); + prevPrefix = keyPrefix; + iterLength++; + } + assertEquals(dataToSort.length, iterLength); + } +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d1761a48babbc..cc50e6d79a3e2 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -46,7 +46,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) flatMappedRDD.checkpoint() - assert(flatMappedRDD.dependencies.head.rdd == parCollection) + assert(flatMappedRDD.dependencies.head.rdd === parCollection) val result = flatMappedRDD.collect() assert(flatMappedRDD.dependencies.head.rdd != parCollection) assert(flatMappedRDD.collect() === result) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 9c191ed52206d..2300bcff4f118 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -107,7 +107,9 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext(clusterUrl, "test") val accum = sc.accumulator(0) val thrown = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) + // scalastyle:on println } assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("failed 4 times")) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index a8c8c6f73fb5a..b099cd3fb7965 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -130,7 +130,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // Non-serializable closure in foreach function val thrown2 = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 2).foreach(x => println(a)) + // scalastyle:on println } assert(thrown2.getClass === classOf[SparkException]) assert(thrown2.getMessage.contains("NotSerializableException") || diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 6e65b0a8f6c76..876418aa13029 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -51,7 +51,9 @@ class FileServerSuite extends SparkFunSuite with 
LocalSparkContext { val textFile = new File(testTempDir, "FileServerSuite.txt") val pw = new PrintWriter(textFile) + // scalastyle:off println pw.println("100") + // scalastyle:on println pw.close() val jarFile = new File(testTempDir, "test.jar") diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 911b3bddd1836..b31b09196608f 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -17,64 +17,145 @@ package org.apache.spark -import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.storage.BlockManagerId +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.RpcUtils -import org.scalatest.concurrent.Eventually._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ManualClock -class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext { +class HeartbeatReceiverSuite + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester + with LocalSparkContext { - test("HeartbeatReceiver") { + private val executorId1 = "executor-1" + private val executorId2 = "executor-2" + + // Shared state that must be reset before and after each test + private var scheduler: TaskScheduler = null + private var heartbeatReceiver: HeartbeatReceiver = null + private var heartbeatReceiverRef: RpcEndpointRef = null + private var heartbeatReceiverClock: ManualClock = null + + override def beforeEach(): Unit = { sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + scheduler = mock(classOf[TaskScheduler]) when(sc.taskScheduler).thenReturn(scheduler) + heartbeatReceiverClock = new ManualClock + heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) + heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + override def afterEach(): Unit = { + resetSparkContext() + scheduler = null + heartbeatReceiver = null + heartbeatReceiverRef = null + heartbeatReceiverClock = null + } - val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + test("task scheduler is set correctly") { + assert(heartbeatReceiver.scheduler === null) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + assert(heartbeatReceiver.scheduler !== null) + } - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> 
metrics)), Matchers.eq(blockManagerId)) - assert(false === response.reregisterBlockManager) + test("normal heartbeat") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 2) + assert(trackedExecutors.contains(executorId1)) + assert(trackedExecutors.contains(executorId2)) } - test("HeartbeatReceiver re-register") { - sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) - when(sc.taskScheduler).thenReturn(scheduler) + test("reregister if scheduler is not ready yet") { + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + // Task scheduler not set in HeartbeatReceiver + triggerHeartbeat(executorId1, executorShouldReregister = true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + test("reregister if heartbeat from unregistered executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + // Received heartbeat from unknown receiver, so we ask it to re-register + triggerHeartbeat(executorId1, executorShouldReregister = true) + assert(executorLastSeen(heartbeatReceiver).isEmpty) + } + + test("reregister if heartbeat from removed executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + // Remove the second executor but not the first + heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy")) + // Now trigger the heartbeats + // A heartbeat from the second executor should require reregistering + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = true) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + test("expire dead hosts") { + val executorTimeout = executorTimeoutMs(heartbeatReceiver) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + // Advance the clock and only trigger a heartbeat for the first executor + heartbeatReceiverClock.advance(executorTimeout / 2) + triggerHeartbeat(executorId1, executorShouldReregister = false) + heartbeatReceiverClock.advance(executorTimeout) + heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + // Only the second executor should be 
expired as a dead host + verify(scheduler).executorLost(Matchers.eq(executorId2), any()) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + + /** Manually send a heartbeat and return the response. */ + private def triggerHeartbeat( + executorId: String, + executorShouldReregister: Boolean): Unit = { val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + val blockManagerId = BlockManagerId(executorId, "localhost", 12345) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + Heartbeat(executorId, Array(1L -> metrics), blockManagerId)) + if (executorShouldReregister) { + assert(response.reregisterBlockManager) + } else { + assert(!response.reregisterBlockManager) + // Additionally verify that the scheduler callback is called with the correct parameters + verify(scheduler).executorHeartbeatReceived( + Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + } + } - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) - assert(true === response.reregisterBlockManager) + // Helper methods to access private fields in HeartbeatReceiver + private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) + private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) + private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = { + receiver invokePrivate _executorLastSeen() + } + private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = { + receiver invokePrivate _executorTimeoutMs() } + } diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 376481ba541fa..25b79bce6ab98 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import javax.net.ssl.SSLContext import com.google.common.io.Files import org.apache.spark.util.Utils @@ -29,6 +30,15 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + // Pick two cipher suites that the provider knows about + val sslContext = SSLContext.getInstance("TLSv1.2") + sslContext.init(null, null, null) + val algorithms = sslContext + .getServerSocketFactory + .getDefaultCipherSuites + .take(2) + .toSet + val conf = new SparkConf conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -36,9 +46,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ssl.protocol", "SSLv3") + conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) + conf.set("spark.ssl.protocol", "TLSv1.2") val opts 
= SSLOptions.parse(conf, "spark.ssl") @@ -52,9 +61,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.trustStorePassword === Some("password")) assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) - assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.protocol === Some("TLSv1.2")) + assert(opts.enabledAlgorithms === algorithms) } test("test resolving property with defaults specified ") { diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 1a099da2c6c8e..33270bec6247c 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -25,6 +25,20 @@ object SSLSampleConfigs { this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + val enabledAlgorithms = + // A reasonable set of TLSv1.2 Oracle security provider suites + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "TLS_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + + // and their equivalent names in the IBM Security provider + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "SSL_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") @@ -33,9 +47,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } @@ -47,9 +60,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e9b64aa82a17a..f34aefca4eb18 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -127,6 +127,17 @@ class SecurityManagerSuite extends SparkFunSuite { test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() + val expectedAlgorithms = Set( + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "TLS_RSA_WITH_AES_256_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "SSL_RSA_WITH_AES_256_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", + 
"SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") val securityManager = new SecurityManager(conf) @@ -143,9 +154,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -154,9 +164,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) - assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 9fbaeb33f97cd..90cb7da94e88a 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -260,10 +260,10 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst assert(RpcUtils.retryWaitMs(conf) === 2L) conf.set("spark.akka.askTimeout", "3") - assert(RpcUtils.askTimeout(conf) === (3 seconds)) + assert(RpcUtils.askRpcTimeout(conf).duration === (3 seconds)) conf.set("spark.akka.lookupTimeout", "4") - assert(RpcUtils.lookupTimeout(conf) === (4 seconds)) + assert(RpcUtils.lookupRpcTimeout(conf).duration === (4 seconds)) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 6838b35ab4cc8..5c57940fa5f77 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.Utils import scala.concurrent.Await import scala.concurrent.duration.Duration +import org.scalatest.Matchers._ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { @@ -272,4 +273,16 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } } + + test("calling multiple sc.stop() must not throw any exception") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val cnt = sc.parallelize(1 to 4).count() + sc.cancelAllJobs() + sc.stop() + // call stop second time + sc.stop() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 6580139df6c60..48509f0759a3b 100644 --- 
a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -36,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends SparkFunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") @@ -130,8 +130,6 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(100) } if (running.get() != 4) { - println("Waited 1 second without seeing runningThreads = 4 (it was " + - running.get() + "); failing test") ThreadingSuiteState.failed.set(true) } number @@ -143,6 +141,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { } sem.acquire(2) if (ThreadingSuiteState.failed.get()) { + logError("Waited 1 second without seeing runningThreads = 4 (it was " + + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 357ed90be3f5c..e7878bde6fcb0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -51,9 +51,11 @@ class SparkSubmitSuite /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } /** Returns true if the script exits and the given search string is printed. */ @@ -81,6 +83,7 @@ class SparkSubmitSuite } } + // scalastyle:off println test("prints usage on empty input") { testPrematureExit(Array[String](), "Usage: spark-submit") } @@ -243,7 +246,7 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 8 + sysProps should have size 9 sysProps.keys should contain ("SPARK_SUBMIT") sysProps.keys should contain ("spark.master") sysProps.keys should contain ("spark.app.name") @@ -252,6 +255,7 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.submit.deployMode") sysProps("spark.shuffle.spill") should be ("false") } @@ -491,6 +495,7 @@ class SparkSubmitSuite appArgs.executorMemory should be ("2.3g") } } + // scalastyle:on println // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
private def runSparkSubmit(args: Seq[String]): Unit = { @@ -548,6 +553,7 @@ object JarCreationTest extends Logging { if (result.nonEmpty) { throw new Exception("Could not load user class from jar:\n" + result(0)) } + sc.stop() } } @@ -573,6 +579,7 @@ object SimpleApplicationTest { s"Master had $config=$masterValue but executor had $config=$executorValue") } } + sc.stop() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 12c40f0b7d658..01ece1a10f46d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -41,9 +41,11 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } override def beforeAll() { @@ -77,9 +79,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => - if (i > 3) { - assert(resolver.getName === s"repo-${i - 3}") - assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 4)) + if (i < 3) { + assert(resolver.getName === s"repo-${i + 1}") + assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i)) } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 09075eeb539aa..2a62450bcdbad 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -39,6 +39,8 @@ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { + import FsHistoryProvider._ + private var testDir: File = null before { @@ -67,7 +69,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) writeFile(newAppComplete, true, None, - SparkListenerApplicationStart("new-app-complete", None, 1L, "test", None), + SparkListenerApplicationStart(newAppComplete.getName(), Some("new-app-complete"), 1L, "test", + None), SparkListenerApplicationEnd(5L) ) @@ -75,35 +78,30 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val newAppCompressedComplete = newLogFile("new1compressed", None, inProgress = false, Some("lzf")) writeFile(newAppCompressedComplete, true, None, - SparkListenerApplicationStart("new-app-compressed-complete", None, 1L, "test", None), + SparkListenerApplicationStart(newAppCompressedComplete.getName(), Some("new-complete-lzf"), + 1L, "test", None), SparkListenerApplicationEnd(4L)) // Write an unfinished app, new-style. 
val newAppIncomplete = newLogFile("new2", None, inProgress = true) writeFile(newAppIncomplete, true, None, - SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test", None) + SparkListenerApplicationStart(newAppIncomplete.getName(), Some("new-incomplete"), 1L, "test", + None) ) // Write an old-style application log. - val oldAppComplete = new File(testDir, "old1") - oldAppComplete.mkdir() - createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("old-app-complete", None, 2L, "test", None), + val oldAppComplete = writeOldLog("old1", "1.0", None, true, + SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) // Check for logs so that we force the older unfinished app to be loaded, to make // sure unfinished apps are also sorted correctly. provider.checkForLogs() // Write an unfinished app, old-style. - val oldAppIncomplete = new File(testDir, "old2") - oldAppIncomplete.mkdir() - createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test", None) + val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, + SparkListenerApplicationStart("old2", None, 2L, "test", None) ) // Force a reload of data from the log directory, and check that both logs are loaded. @@ -124,16 +122,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) } - list(0) should be (makeAppInfo(newAppComplete.getName(), "new-app-complete", 1L, 5L, + list(0) should be (makeAppInfo("new-app-complete", newAppComplete.getName(), 1L, 5L, newAppComplete.lastModified(), "test", true)) - list(1) should be (makeAppInfo(newAppCompressedComplete.getName(), - "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test", - true)) - list(2) should be (makeAppInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L, + list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), + 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, -1L, - oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, -1L, + list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, + -1L, oldAppIncomplete.lastModified(), "test", false)) + list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. 
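Note on the listing assertions above: new-style test logs now carry an explicit application ID as the second argument of SparkListenerApplicationStart, and new-style logs whose start event has no ID are ignored by the provider (see the SPARK-8372 test below). A minimal sketch of how such a log is produced, using the suite's own newLogFile/writeFile helpers and a hypothetical "example-app-id":

    val exampleLog = newLogFile("example", None, inProgress = false)
    writeFile(exampleLog, true, None,
      SparkListenerApplicationStart(exampleLog.getName(), Some("example-app-id"), 1L, "test", None),
      SparkListenerApplicationEnd(2L))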
@@ -155,12 +152,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null val logDir = new File(testDir, codecName) logDir.mkdir() - createEmptyFile(new File(logDir, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, provider.LOG_PREFIX + "1"), false, Option(codec), + createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), SparkListenerApplicationStart("app2", None, 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(logDir, provider.COMPRESSION_CODEC_PREFIX + codecName)) + createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) val logPath = new Path(logDir.getAbsolutePath()) try { @@ -180,12 +177,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, - SparkListenerApplicationStart("app1-1", None, 1L, "test", None), + SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), SparkListenerApplicationEnd(2L) ) val logFile2 = newLogFile("new2", None, inProgress = false) writeFile(logFile2, true, None, - SparkListenerApplicationStart("app1-2", None, 1L, "test", None), + SparkListenerApplicationStart("app1-2", Some("app1-2"), 1L, "test", None), SparkListenerApplicationEnd(2L) ) logFile2.setReadable(false, false) @@ -218,6 +215,18 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("Parse logs that application is not started") { + val provider = new FsHistoryProvider((createTestConf())) + + val logFile1 = newLogFile("app1", None, inProgress = true) + writeFile(logFile1, true, None, + SparkListenerLogStart("1.4") + ) + updateAndCheck(provider) { list => + list.size should be (0) + } + } + test("SPARK-5582: empty log directory") { val provider = new FsHistoryProvider(createTestConf()) @@ -373,6 +382,33 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("SPARK-8372: new logs with no app ID are ignored") { + val provider = new FsHistoryProvider(createTestConf()) + + // Write a new log file without an app id, to make sure it's ignored. + val logFile1 = newLogFile("app1", None, inProgress = true) + writeFile(logFile1, true, None, + SparkListenerLogStart("1.4") + ) + + // Write a 1.2 log file with no start event (= no app id), it should be ignored. + writeOldLog("v12Log", "1.2", None, false) + + // Write 1.0 and 1.1 logs, which don't have app ids. + writeOldLog("v11Log", "1.1", None, true, + SparkListenerApplicationStart("v11Log", None, 2L, "test", None), + SparkListenerApplicationEnd(3L)) + writeOldLog("v10Log", "1.0", None, true, + SparkListenerApplicationStart("v10Log", None, 2L, "test", None), + SparkListenerApplicationEnd(4L)) + + updateAndCheck(provider) { list => + list.size should be (2) + list(0).id should be ("v10Log") + list(1).id should be ("v11Log") + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. 
Example: @@ -412,4 +448,23 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } + private def writeOldLog( + fname: String, + sparkVersion: String, + codec: Option[CompressionCodec], + completed: Boolean, + events: SparkListenerEvent*): File = { + val log = new File(testDir, fname) + log.mkdir() + + val oldEventLog = new File(log, LOG_PREFIX + "1") + createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) + writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) + if (completed) { + createEmptyFile(new File(log, APPLICATION_COMPLETE)) + } + + log + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 014e87bb40254..9cb6dd43bac47 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,63 +19,21 @@ package org.apache.spark.deploy.master import java.util.Date -import scala.concurrent.Await import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps -import akka.actor.Address import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy._ class MasterSuite extends SparkFunSuite with Matchers with Eventually { - test("toAkkaUrl") { - val conf = new SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp") - assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl with SSL") { - val conf = new SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl: a typo url") { - val conf = new SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } - - test("toAkkaAddress") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp") - assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress with SSL") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress: a typo url") { - val conf = new SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 
3.4:1234" === e.getMessage) - } - test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) conf.set("spark.deploy.recoveryMode", "CUSTOM") @@ -129,16 +87,16 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { port = 10000, cores = 0, memory = 0, - actor = null, + endpoint = null, webUiPort = 0, publicAddress = "" ) - val (actorSystem, port, uiPort, restPort) = - Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf) + val (rpcEnv, uiPort, restPort) = + Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf) try { - Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds) + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get @@ -154,8 +112,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { workers.map(_.id) should contain(workerToPersist.id) } finally { - actorSystem.shutdown() - actorSystem.awaitTermination() + rpcEnv.shutdown() + rpcEnv.awaitTermination() } CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 197f68e7ec5ed..96e456d889ac3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -23,14 +23,14 @@ import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.google.common.base.Charsets import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} import org.apache.spark.deploy.master.DriverState._ @@ -39,11 +39,11 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. */ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { - private var actorSystem: Option[ActorSystem] = None + private var rpcEnv: Option[RpcEnv] = None private var server: Option[RestSubmissionServer] = None override def afterEach() { - actorSystem.foreach(_.shutdown()) + rpcEnv.foreach(_.shutdown()) server.foreach(_.stop()) } @@ -377,31 +377,32 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { killMessage: String = "driver is killed", state: DriverState = FINISHED, exception: Option[Exception] = None): String = { - startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + startServer(new DummyMaster(_, submitId, submitMessage, killMessage, state, exception)) } /** Start a smarter dummy server that keeps track of submitted driver states. */ private def startSmartServer(): String = { - startServer(new SmarterMaster) + startServer(new SmarterMaster(_)) } /** Start a dummy server that is faulty in many ways... 
*/ private def startFaultyServer(): String = { - startServer(new DummyMaster, faulty = true) + startServer(new DummyMaster(_), faulty = true) } /** - * Start a [[StandaloneRestServer]] that communicates with the given actor. + * Start a [[StandaloneRestServer]] that communicates with the given endpoint. * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. * Return the master URL that corresponds to the address of this server. */ - private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + private def startServer( + makeFakeMaster: RpcEnv => RpcEndpoint, faulty: Boolean = false): String = { val name = "test-standalone-rest-protocol" val conf = new SparkConf val localhost = Utils.localHostName() val securityManager = new SecurityManager(conf) - val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) - val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _rpcEnv = RpcEnv.create(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _rpcEnv.setupEndpoint("fake-master", makeFakeMaster(_rpcEnv)) val _server = if (faulty) { new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") @@ -410,7 +411,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { } val port = _server.start() // set these to clean them up after every test - actorSystem = Some(_actorSystem) + rpcEnv = Some(_rpcEnv) server = Some(_server) s"spark://$localhost:$port" } @@ -505,20 +506,21 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { * In all responses, the success parameter is always true. */ private class DummyMaster( + override val rpcEnv: RpcEnv, submitId: String = "fake-driver-id", submitMessage: String = "submitted", killMessage: String = "killed", state: DriverState = FINISHED, exception: Option[Exception] = None) - extends Actor { + extends RpcEndpoint { - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => - sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) + context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage)) case RequestKillDriver(driverId) => - sender ! KillDriverResponse(driverId, success = true, killMessage) + context.reply(KillDriverResponse(self, driverId, success = true, killMessage)) case RequestDriverStatus(driverId) => - sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception)) } } @@ -531,28 +533,28 @@ private class DummyMaster( * Submits are always successful while kills and status requests are successful only * if the driver was submitted in the past. */ -private class SmarterMaster extends Actor { +private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { private var counter: Int = 0 private val submittedDrivers = new mutable.HashMap[String, DriverState] - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => val driverId = s"driver-$counter" submittedDrivers(driverId) = RUNNING counter += 1 - sender ! 
SubmitDriverResponse(success = true, Some(driverId), "submitted") + context.reply(SubmitDriverResponse(self, success = true, Some(driverId), "submitted")) case RequestKillDriver(driverId) => val success = submittedDrivers.contains(driverId) if (success) { submittedDrivers(driverId) = KILLED } - sender ! KillDriverResponse(driverId, success, "killed") + context.reply(KillDriverResponse(self, driverId, success, "killed")) case RequestDriverStatus(driverId) => val found = submittedDrivers.contains(driverId) val state = submittedDrivers.get(driverId) - sender ! DriverStatusResponse(found, state, None, None, None) + context.reply(DriverStatusResponse(found, state, None, None, None)) } } @@ -568,7 +570,7 @@ private class FaultyStandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { @@ -578,7 +580,7 @@ private class FaultyStandaloneRestServer( /** A faulty servlet that produces malformed responses. */ class MalformedSubmitServlet - extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { + extends StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -588,7 +590,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterEndpoint, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -597,7 +599,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. */ - class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 115ac0534a1b4..725b8848bc052 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.deploy.rest import java.lang.Boolean -import java.lang.Integer import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils /** * Tests for the REST application submission protocol. 
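The fake masters above follow the Akka-to-RpcEnv migration pattern used throughout this patch: an Actor that answered with sender ! reply becomes an RpcEndpoint (or ThreadSafeRpcEndpoint) whose receiveAndReply answers through its RpcCallContext. A minimal sketch of that shape; EchoEndpoint and the "echo" name are illustrative, not part of the patch:

    private class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(msg)  // previously: sender ! msg
      }
    }
    // Registered the same way the fake masters are wired up in startServer():
    // val endpointRef = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))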
@@ -93,7 +93,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { // optional fields conf.set("spark.jars", "mayonnaise.jar,ketchup.jar") conf.set("spark.files", "fireball.png") - conf.set("spark.driver.memory", "512m") + conf.set("spark.driver.memory", s"${Utils.DEFAULT_DRIVER_MEM_MB}m") conf.set("spark.driver.cores", "180") conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red") conf.set("spark.driver.extraClassPath", "food-coloring.jar") @@ -126,7 +126,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { assert(newMessage.sparkProperties("spark.app.name") === "SparkPie") assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar") assert(newMessage.sparkProperties("spark.files") === "fireball.png") - assert(newMessage.sparkProperties("spark.driver.memory") === "512m") + assert(newMessage.sparkProperties("spark.driver.memory") === s"${Utils.DEFAULT_DRIVER_MEM_MB}m") assert(newMessage.sparkProperties("spark.driver.cores") === "180") assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") @@ -230,7 +230,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { """.stripMargin private val submitDriverRequestJson = - """ + s""" |{ | "action" : "CreateSubmissionRequest", | "appArgs" : [ "two slices", "a hint of cinnamon" ], @@ -246,7 +246,7 @@ class SubmitRestProtocolSuite extends SparkFunSuite { | "spark.driver.supervise" : "false", | "spark.app.name" : "SparkPie", | "spark.cores.max" : "10000", - | "spark.driver.memory" : "512m", + | "spark.driver.memory" : "${Utils.DEFAULT_DRIVER_MEM_MB}m", | "spark.files" : "fireball.png", | "spark.driver.cores" : "180", | "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index ac18f04a11475..cd24d79423316 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.worker -import akka.actor.AddressFromURIString import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -26,13 +25,11 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) + val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected( - RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } @@ -40,13 +37,13 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val otherAkkaURL = 
"akka://test@4.3.2.1:1234/user/OtherActor" - val otherAkkaAddress = AddressFromURIString(otherAkkaURL) + val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val otherAddress = "akka://test@4.3.2.1:1234/user/OtherActor" + val otherAkkaAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + workerWatcher.onDisconnected(otherAkkaAddress) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 63947df3d43a2..8a199459c1ddf 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -36,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. */ -class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ @@ -85,7 +85,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl */ test("Correctness of WholeTextFileRecordReader.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, false) @@ -109,7 +109,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl test("Correctness of WholeTextFileRecordReader with GzipCodec.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, true) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 9e4d34fb7d382..d3218a548efc7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -60,7 +60,9 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { + // scalastyle:off println pw.println(RandomUtils.nextInt(0, numBuckets)) + // scalastyle:on println } pw.close() diff --git 
a/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala new file mode 100644 index 0000000000000..b3223ec61bf79 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.{SparkException, SparkFunSuite} + +class RpcAddressSuite extends SparkFunSuite { + + test("hostPort") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + assert(address.hostPort == "1.2.3.4:1234") + } + + test("fromSparkURL") { + val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234") + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + } + + test("fromSparkURL: a typo url") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("spark://1.2. 3.4:1234") + } + assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) + } + + test("fromSparkURL: invalid scheme") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("invalid://1.2.3.4:1234") + } + assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage) + } + + test("toSparkURL") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.toSparkURL == "spark://1.2.3.4:1234") + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 1f0aa759b08da..6ceafe4337747 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -155,16 +155,21 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) val conf = new SparkConf() + val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { - val e = intercept[Exception] { - rpcEndpointRef.askWithRetry[String]("hello", 1 millis) + // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause + val e = intercept[SparkException] { + rpcEndpointRef.askWithRetry[String]("hello", new RpcTimeout(1 millis, shortProp)) } - assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + // The SparkException cause should be a RpcTimeoutException with message indicating the + // controlling timeout property + assert(e.getCause.isInstanceOf[RpcTimeoutException]) + assert(e.getCause.getMessage.contains(shortProp)) } finally { anotherEnv.shutdown() 
anotherEnv.awaitTermination() @@ -539,6 +544,92 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("construct RpcTimeout with conf property") { + val conf = new SparkConf + + val testProp = "spark.ask.test.timeout" + val testDurationSeconds = 30 + val secondaryProp = "spark.ask.secondary.timeout" + + conf.set(testProp, s"${testDurationSeconds}s") + conf.set(secondaryProp, "100s") + + // Construct RpcTimeout with a single property + val rt1 = RpcTimeout(conf, testProp) + assert( testDurationSeconds === rt1.duration.toSeconds ) + + // Construct RpcTimeout with prioritized list of properties + val rt2 = RpcTimeout(conf, Seq("spark.ask.invalid.timeout", testProp, secondaryProp), "1s") + assert( testDurationSeconds === rt2.duration.toSeconds ) + + // Construct RpcTimeout with default value, + val defaultProp = "spark.ask.default.timeout" + val defaultDurationSeconds = 1 + val rt3 = RpcTimeout(conf, Seq(defaultProp), defaultDurationSeconds.toString + "s") + assert( defaultDurationSeconds === rt3.duration.toSeconds ) + assert( rt3.timeoutProp.contains(defaultProp) ) + + // Try to construct RpcTimeout with an unconfigured property + intercept[NoSuchElementException] { + RpcTimeout(conf, "spark.ask.invalid.timeout") + } + } + + test("ask a message timeout on Future using RpcTimeout") { + case class NeverReply(msg: String) + + val rpcEndpointRef = env.setupEndpoint("ask-future", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.reply(msg) + case _: NeverReply => + } + }) + + val longTimeout = new RpcTimeout(1 second, "spark.rpc.long.timeout") + val shortTimeout = new RpcTimeout(10 millis, "spark.rpc.short.timeout") + + // Ask with immediate response, should complete successfully + val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout) + val reply1 = longTimeout.awaitResult(fut1) + assert("hello" === reply1) + + // Ask with a delayed response and wait for response immediately that should timeout + val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout) + val reply2 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut2) + }.getMessage + + // RpcTimeout.awaitResult should have added the property to the TimeoutException message + assert(reply2.contains(shortTimeout.timeoutProp)) + + // Ask with delayed response and allow the Future to timeout before Await.result + val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout) + + // Allow future to complete with failure using plain Await.result, this will return + // once the future is complete to verify addMessageIfTimeout was invoked + val reply3 = + intercept[RpcTimeoutException] { + Await.result(fut3, 200 millis) + }.getMessage + + // When the future timed out, the recover callback should have used + // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message + assert(reply3.contains(shortTimeout.timeoutProp)) + + // Use RpcTimeout.awaitResult to process Future, since it has already failed with + // RpcTimeoutException, the same RpcTimeoutException should be thrown + val reply4 = + intercept[RpcTimeoutException] { + shortTimeout.awaitResult(fut3) + }.getMessage + + // Ensure description is not in message twice after addMessageIfTimeout and awaitResult + assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) + } + } class UnserializableClass diff --git 
a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index a33a83db7bc9e..4aa75c9230b2c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc.akka import org.apache.spark.rpc._ -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { @@ -47,4 +47,22 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { } } + test("uriOf") { + val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } + + test("uriOf: ssl") { + val conf = SSLSampleConfigs.sparkSSLConfig() + val securityManager = new SecurityManager(conf) + val rpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, securityManager)) + try { + val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } finally { + rpcEnv.shutdown() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index ff3fa95ec32ae..4e3defb43a021 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -52,8 +52,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) + // scalastyle:off println writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) + // scalastyle:on println writer.close() val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..3f1692917a357 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util +import java.util.Collections + +import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.Protos._ +import org.apache.mesos.SchedulerDriver +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.scalatest.mock.MockitoSugar +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} + +class CoarseMesosSchedulerBackendSuite extends SparkFunSuite + with LocalSparkContext + with MockitoSugar + with BeforeAndAfter { + + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder() + .setValue(offerId).build()) + .setFrameworkId(FrameworkID.newBuilder() + .setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setHostname(s"host${slaveId}") + .build() + } + + private def createSchedulerBackend( + taskScheduler: TaskSchedulerImpl, + driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { + mesosDriver = driver + markRegistered() + } + backend.start() + backend + } + + var sparkConf: SparkConf = _ + + before { + sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + + sc = new SparkContext(sparkConf) + } + + test("mesos supports killing and limiting executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) + + val taskID0 = TaskID.newBuilder().setValue("0").build() + + backend.resourceOffers(driver, mesosOffers) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + // simulate the allocation manager down-scaling executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("s1/0"))) + verify(driver, times(1)).killTask(taskID0) + + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) + backend.resourceOffers(driver, mesosOffers2) + + verify(driver, times(1)) + .declineOffer(OfferID.newBuilder().setValue("o2").build()) + + // Verify we didn't launch any new executor + assert(backend.slaveIdsWithExecutors.size === 1) + + backend.doRequestTotalExecutors(2) + backend.resourceOffers(driver, mesosOffers2) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + assert(backend.slaveIdsWithExecutors.size === 2) + backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) + assert(backend.slaveIdsWithExecutors.size === 1) + } + + 
test("mesos supports killing and relaunching tasks with executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + 1024 + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + val offer1 = createOffer("o1", "s1", minMem, minCpu) + mesosOffers.add(offer1) + + val offer2 = createOffer("o2", "s1", minMem, 1); + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer1.getId)), + anyObject(), + anyObject[Filters]) + + // Simulate task killed, executor no longer running + val status = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue("0").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setState(TaskState.TASK_KILLED) + .build + + backend.statusUpdate(driver, status) + assert(!backend.slaveIdsWithExecutors.contains("s1")) + + mesosOffers.clear() + mesosOffers.add(offer2) + backend.resourceOffers(driver, mesosOffers) + assert(backend.slaveIdsWithExecutors.contains("s1")) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer2.getId)), + anyObject(), + anyObject[Filters]) + + verify(driver, times(1)).reviveOffers() + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala deleted file mode 100644 index e72285d03d3ee..0000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} - -class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { - test("MesosMemoryUtils should always override memoryOverhead when it's set") { - val sparkConf = new SparkConf - - val sc = mock[SparkContext] - when(sc.conf).thenReturn(sparkConf) - - // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 - when(sc.executorMemory).thenReturn(512) - assert(MemoryUtils.calculateTotalMemory(sc) === 896) - - // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 - when(sc.executorMemory).thenReturn(4096) - assert(MemoryUtils.calculateTotalMemory(sc) === 4505) - - // set memoryOverhead - sparkConf.set("spark.mesos.executor.memoryOverhead", "100") - assert(MemoryUtils.calculateTotalMemory(sc) === 4196) - sparkConf.set("spark.mesos.executor.memoryOverhead", "400") - assert(MemoryUtils.calculateTotalMemory(sc) === 4496) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 68df46a41ddc8..d01837fe78957 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -149,7 +149,9 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(sc.conf).thenReturn(new SparkConf) when(sc.listenerBus).thenReturn(listenerBus) - val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] @@ -157,8 +159,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi mesosOffers.add(createOffer(2, minMem - 1, minCpu)) mesosOffers.add(createOffer(3, minMem, minCpu)) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala new file mode 100644 index 0000000000000..b354914b6ffd0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.Value +import org.mockito.Mockito._ +import org.scalatest._ +import org.scalatest.mock.MockitoSugar +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { + + // scalastyle:off structural.type + // this is the documented way of generating fixtures in scalatest + def fixture: Object {val sc: SparkContext; val sparkConf: SparkConf} = new { + val sparkConf = new SparkConf + val sc = mock[SparkContext] + when(sc.conf).thenReturn(sparkConf) + } + val utils = new MesosSchedulerUtils { } + // scalastyle:on structural.type + + test("use at-least minimum overhead") { + val f = fixture + when(f.sc.executorMemory).thenReturn(512) + utils.calculateTotalMemory(f.sc) shouldBe 896 + } + + test("use overhead if it is greater than minimum value") { + val f = fixture + when(f.sc.executorMemory).thenReturn(4096) + utils.calculateTotalMemory(f.sc) shouldBe 4505 + } + + test("use spark.mesos.executor.memoryOverhead (if set)") { + val f = fixture + when(f.sc.executorMemory).thenReturn(1024) + f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") + utils.calculateTotalMemory(f.sc) shouldBe 1536 + } + + test("parse a non-empty constraint string correctly") { + val expectedMap = Map( + "tachyon" -> Set("true"), + "zone" -> Set("us-east-1a", "us-east-1b") + ) + utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap) + } + + test("parse an empty constraint string correctly") { + utils.parseConstraintString("") shouldBe Map() + } + + test("throw an exception when the input is malformed") { + an[IllegalArgumentException] should be thrownBy + utils.parseConstraintString("tachyon;zone:us-east") + } + + test("empty values for attributes' constraints matches all values") { + val constraintsStr = "tachyon:" + val parsedConstraints = utils.parseConstraintString(constraintsStr) + + parsedConstraints shouldBe Map("tachyon" -> Set()) + + val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build() + val noTachyonOffer = Map("zone" -> zoneSet) + val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build()) + + utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false + utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true + } + + test("subset match is performed for set attributes") { + val supersetConstraint = Map( + "tachyon" -> Value.Text.newBuilder().setValue("true").build(), + "zone" -> Value.Set.newBuilder() + .addItem("us-east-1a") + .addItem("us-east-1b") + .addItem("us-east-1c") + .build()) + + val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c" + val parsedConstraints = utils.parseConstraintString(zoneConstraintStr) + + utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true + } + + test("less than equal match is performed on scalar attributes") { + val offerAttribs = Map("gpus" -> Value.Scalar.newBuilder().setValue(3).build()) + + val ltConstraint = utils.parseConstraintString("gpus:2") + val eqConstraint = utils.parseConstraintString("gpus:3") + val gtConstraint = utils.parseConstraintString("gpus:4") + + 
utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + } + + test("contains match is performed for range attributes") { + val offerAttribs = Map("ports" -> Value.Range.newBuilder().setBegin(7000).setEnd(8000).build()) + val ltConstraint = utils.parseConstraintString("ports:6000") + val eqConstraint = utils.parseConstraintString("ports:7500") + val gtConstraint = utils.parseConstraintString("ports:8002") + val multiConstraint = utils.parseConstraintString("ports:5000,7500,8300") + + utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(multiConstraint, offerAttribs) shouldBe true + } + + test("equality match is performed for text attributes") { + val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + + val trueConstraint = utils.parseConstraintString("tachyon:true") + val falseConstraint = utils.parseConstraintString("tachyon:false") + + utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false + } + +} diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 1053c6caf7718..480722a5ac182 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -375,6 +375,7 @@ class TestCreateNullValue { // parameters of the closure constructor. This allows us to test whether // null values are created correctly for each type. val nestedClosure = () => { + // scalastyle:off println if (s.toString == "123") { // Don't really output them to avoid noisy println(bo) println(c) @@ -389,6 +390,7 @@ class TestCreateNullValue { val closure = () => { println(getX) } + // scalastyle:on println ClosureCleaner.clean(closure) } nestedClosure() diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a61ea3918f46a..c7638507c88c6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -486,11 +486,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // Test for using the util function to change our log levels. test("log4j log level change") { - Utils.setLogLevel(org.apache.log4j.Level.ALL) - assert(log.isInfoEnabled()) - Utils.setLogLevel(org.apache.log4j.Level.ERROR) - assert(!log.isInfoEnabled()) - assert(log.isErrorEnabled()) + val current = org.apache.log4j.Logger.getRootLogger().getLevel() + try { + Utils.setLogLevel(org.apache.log4j.Level.ALL) + assert(log.isInfoEnabled()) + Utils.setLogLevel(org.apache.log4j.Level.ERROR) + assert(!log.isInfoEnabled()) + assert(log.isErrorEnabled()) + } finally { + // Best effort at undoing changes this test made. 
+ Utils.setLogLevel(current) + } } test("deleteRecursively") { @@ -673,4 +679,14 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, parentDir)) assert(!Utils.isInDirectory(nullFile, childFile3)) } + + test("circular buffer") { + val buffer = new CircularBuffer(25) + val stream = new java.io.PrintStream(buffer, true, "UTF-8") + + // scalastyle:off println + stream.println("test circular test circular test circular test circular test circular") + // scalastyle:on println + assert(buffer.toString === "t circular test circular\n") + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 5a5919fca2469..4f382414a8dd7 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -103,7 +103,9 @@ private object SizeTrackerSuite { */ def main(args: Array[String]): Unit = { if (args.size < 1) { + // scalastyle:off println println("Usage: SizeTrackerSuite [num elements]") + // scalastyle:on println System.exit(1) } val numElements = args(0).toInt @@ -180,11 +182,13 @@ private object SizeTrackerSuite { baseTimes: Seq[Long], sampledTimes: Seq[Long], unsampledTimes: Seq[Long]): Unit = { + // scalastyle:off println println(s"Average times for $testName (ms):") println(" Base - " + averageTime(baseTimes)) println(" SizeTracker (sampled) - " + averageTime(sampledTimes)) println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes)) println() + // scalastyle:on println } def time(f: => Unit): Long = { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index b2f5d9009ee5d..fefa5165db197 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends SparkFunSuite { +class SorterSuite extends SparkFunSuite with Logging { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) @@ -74,7 +74,7 @@ class SorterSuite extends SparkFunSuite { /** Runs an experiment several times. 
*/ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { if (skip) { - println(s"Skipped experiment $name.") + logInfo(s"Skipped experiment $name.") return } @@ -86,11 +86,11 @@ class SorterSuite extends SparkFunSuite { while (i < 10) { val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) next10 += time - println(s"$name: Took $time ms") + logInfo(s"$name: Took $time ms") i += 1 } - println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + logInfo(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") } /** diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala new file mode 100644 index 0000000000000..dd505dfa7d758 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort + +import org.scalatest.prop.PropertyChecks + +import org.apache.spark.SparkFunSuite + +class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { + + test("String prefix comparator") { + + def testPrefixComparison(s1: String, s2: String): Unit = { + val s1Prefix = PrefixComparators.STRING.computePrefix(s1) + val s2Prefix = PrefixComparators.STRING.computePrefix(s2) + val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + assert( + (prefixComparisonResult == 0) || + (prefixComparisonResult < 0 && s1 < s2) || + (prefixComparisonResult > 0 && s1 > s2)) + } + + // scalastyle:off + val regressionTests = Table( + ("s1", "s2"), + ("abc", "世界"), + ("你好", "世界"), + ("你好123", "你好122") + ) + // scalastyle:on + + forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + } +} diff --git a/data/mllib/pic_data.txt b/data/mllib/pic_data.txt new file mode 100644 index 0000000000000..fcfef8cd19131 --- /dev/null +++ b/data/mllib/pic_data.txt @@ -0,0 +1,19 @@ +0 1 1.0 +0 2 1.0 +0 3 1.0 +1 2 1.0 +1 3 1.0 +2 3 1.0 +3 4 0.1 +4 5 1.0 +4 15 1.0 +5 6 1.0 +6 7 1.0 +7 8 1.0 +8 9 1.0 +9 10 1.0 +10 11 1.0 +11 12 1.0 +12 13 1.0 +13 14 1.0 +14 15 1.0 diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index fc03fec9866a6..61d91c70e9709 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package main.scala import scala.util.Try @@ -59,3 +60,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala index 0be8e64fbfabd..9f7ae75d0b477 100644 --- a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -37,3 +38,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala index 24c7f8d667296..2f0b6ef9a5672 100644 --- a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala +++ b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import org.apache.spark.{SparkContext, SparkConf} @@ -51,3 +52,4 @@ object GraphXApp { println("Test succeeded") } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index 5111bc0adb772..4a980ec071ae4 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -55,3 +56,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala index 9f85066501472..adc25b57d6aa5 100644 --- a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -31,3 +32,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index cc86ef45858c9..69c1154dc0955 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -57,3 +58,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala index 58a662bd9b2e8..d6a074687f4a1 100644 --- a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala +++ b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -61,3 +62,4 @@ object SparkStreamingExample { ssc.stop() } } +// scalastyle:on println diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 54274a83f6d66..30190dcd41ec5 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,13 +118,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO - build/mvn -DskipTests -Pyarn -Phive \ + build/mvn -DskipTests -Pyarn -Phive -Prelease\ -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh - build/mvn -DskipTests -Pyarn -Phive \ + build/mvn -DskipTests -Pyarn -Phive -Prelease\ -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index cf827ce89b857..4a17d48d8171d 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -47,6 +47,12 @@ JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") +# OAuth key used for issuing requests against the GitHub API. If this is not defined, then requests +# will be unauthenticated. You should only need to configure this if you find yourself regularly +# exceeding your IP's unauthenticated request rate limit. You can create an OAuth key at +# https://github.com/settings/tokens. This script only requires the "public_repo" scope. +GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY") + GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" @@ -58,9 +64,17 @@ def get_json(url): try: - return json.load(urllib2.urlopen(url)) + request = urllib2.Request(url) + if GITHUB_OAUTH_KEY: + request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY) + return json.load(urllib2.urlopen(request)) except urllib2.HTTPError as e: - print "Unable to fetch URL, exiting: %s" % url + if "X-RateLimit-Remaining" in e.headers and e.headers["X-RateLimit-Remaining"] == '0': + print "Exceeded the GitHub API rate limit; see the instructions in " + \ + "dev/merge_spark_pr.py to configure an OAuth token for making authenticated " + \ + "GitHub requests." + else: + print "Unable to fetch URL, exiting: %s" % url sys.exit(-1) diff --git a/dev/run-tests b/dev/run-tests index a00d9f0c27639..257d1e8d50bb4 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,4 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -exec python -u ./dev/run-tests.py +exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests.py b/dev/run-tests.py index e5c897b94d167..1f0d218514f92 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -19,6 +19,7 @@ from __future__ import print_function import itertools +from optparse import OptionParser import os import re import sys @@ -95,8 +96,8 @@ def determine_modules_to_test(changed_modules): ['examples', 'graphx'] >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) >>> x # doctest: +NORMALIZE_WHITESPACE - ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \ - 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql'] + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. 
No module depends on root, so if it appears then it will be @@ -292,7 +293,8 @@ def build_spark_sbt(hadoop_version): build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "assembly/assembly", - "streaming-kafka-assembly/assembly"] + "streaming-kafka-assembly/assembly", + "streaming-flume-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", @@ -360,12 +362,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): run_scala_tests_sbt(test_modules, test_profiles) -def run_python_tests(test_modules): +def run_python_tests(test_modules, parallelism): set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") command = [os.path.join(SPARK_HOME, "python", "run-tests")] if test_modules != [modules.root]: command.append("--modules=%s" % ','.join(m.name for m in test_modules)) + command.append("--parallelism=%i" % parallelism) run_cmd(command) @@ -379,7 +382,25 @@ def run_sparkr_tests(): print("Ignoring SparkR tests as R was not found in PATH") +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + def main(): + opts = parse_opts() # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): print("[error] Cannot determine your home directory as an absolute path;", @@ -461,7 +482,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: - run_python_tests(modules_with_python_tests) + run_python_tests(modules_with_python_tests, opts.parallelism) if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index efe3a897e9c10..993583e2f4119 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -203,7 +203,7 @@ def contains_file(self, filename): streaming_flume = Module( - name="streaming_flume", + name="streaming-flume", dependencies=[streaming], source_file_regexes=[ "external/flume", @@ -214,6 +214,15 @@ def contains_file(self, filename): ) +streaming_flume_assembly = Module( + name="streaming-flume-assembly", + dependencies=[streaming_flume, streaming_flume_sink], + source_file_regexes=[ + "external/flume-assembly", + ] +) + + mllib = Module( name="mllib", dependencies=[streaming, sql], @@ -241,7 +250,7 @@ def contains_file(self, filename): pyspark_core = Module( name="pyspark-core", - dependencies=[mllib, streaming, streaming_kafka], + dependencies=[], source_file_regexes=[ "python/(?!pyspark/(ml|mllib|sql|streaming))" ], @@ -281,7 +290,7 @@ def contains_file(self, filename): pyspark_streaming = Module( name="pyspark-streaming", - dependencies=[pyspark_core, streaming, streaming_kafka], + dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly], source_file_regexes=[ "python/pyspark/streaming" ], diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index ad9b0cc89e4ab..12bd0bf3a4fe9 100644 --- a/dev/sparktestsupport/shellutils.py +++ 
b/dev/sparktestsupport/shellutils.py @@ -15,6 +15,7 @@ # limitations under the License. # +from __future__ import print_function import os import shutil import subprocess diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index 5956d59130fbf..5dbdb8b22a44f 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -17,13 +17,13 @@ FROM ubuntu:precise -RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list - # Upgrade package index -RUN apt-get update - # install a few other useful packages plus Open Jdk 7 -RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server +# Remove unneeded /var/lib/apt/lists/* after install to reduce the +# docker image size (by ~30MB) +RUN apt-get update && \ + apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 6073b3626c45b..15ceda11a8a80 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,6 +63,51 @@ puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) + + # Begin updating JavaDoc files for badge post-processing + puts "Updating JavaDoc files for badge post-processing" + js_script_start = '' + + javadoc_files = Dir["./" + dest + "/**/*.html"] + javadoc_files.each do |javadoc_file| + # Determine file depths to reference js files + slash_count = javadoc_file.count "/" + i = 3 + path_to_js_file = "" + while (i < slash_count) do + path_to_js_file = path_to_js_file + "../" + i += 1 + end + + # Create script elements to reference js files + javadoc_jquery_script = js_script_start + path_to_js_file + "lib/jquery" + js_script_end; + javadoc_api_docs_script = js_script_start + path_to_js_file + "lib/api-javadocs" + js_script_end; + javadoc_script_elements = javadoc_jquery_script + javadoc_api_docs_script + + # Add script elements to JavaDoc files + javadoc_file_content = File.open(javadoc_file, "r") { |f| f.read } + javadoc_file_content = javadoc_file_content.sub("", javadoc_script_elements + "") + File.open(javadoc_file, "w") { |f| f.puts(javadoc_file_content) } + + end + # End updating JavaDoc files for badge post-processing + + puts "Copying jquery.js from Scala API to Java API for page post-processing of badges" + jquery_src_file = "./api/scala/lib/jquery.js" + jquery_dest_file = "./api/java/lib/jquery.js" + mkdir_p("./api/java/lib") + cp(jquery_src_file, jquery_dest_file) + + puts "Copying api_javadocs.js to Java API for page post-processing of badges" + api_javadocs_src_file = "./js/api-javadocs.js" + api_javadocs_dest_file = "./api/java/lib/api-javadocs.js" + cp(api_javadocs_src_file, api_javadocs_dest_file) + + puts "Appending content of api-javadocs.css to JavaDoc stylesheet.css for badge styles" + css = File.readlines("./css/api-javadocs.css") + css_file = dest + "/stylesheet.css" + File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } end # Build Sphinx docs for Python diff --git a/docs/configuration.md b/docs/configuration.md index affcd21514d88..8a186ee51c1ca 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -137,10 +137,10 @@ of the most common options to set are: spark.driver.memory - 512m + 1g Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 512m, 2g). + (e.g. 1g, 2g).
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -665,7 +665,7 @@ Apart from these, the following properties are also available, and may be useful Initial size of Kryo's serialization buffer. Note that there will be one buffer per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max.mb if needed. + spark.kryoserializer.buffer.max if needed. @@ -1007,9 +1007,9 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.numRetries 3 + Number of times to retry before an RPC task gives up. An RPC task will run at most times of this number. - @@ -1029,8 +1029,8 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out. @@ -1206,7 +1206,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.cachedExecutorIdleTimeout - 2 * executorIdleTimeout + infinity If dynamic allocation is enabled and an executor which has cached data blocks has been idle for more than this duration, the executor will be removed. For more details, see this @@ -1222,7 +1222,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.maxExecutors - Integer.MAX_VALUE + infinity Upper bound for the number of executors if dynamic allocation is enabled. diff --git a/docs/css/api-docs.css b/docs/css/api-docs.css index b2d1d7f869790..7cf222aad24f6 100644 --- a/docs/css/api-docs.css +++ b/docs/css/api-docs.css @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /* Dynamically injected style for the API docs */ .developer { diff --git a/docs/css/api-javadocs.css b/docs/css/api-javadocs.css new file mode 100644 index 0000000000000..832e92609e011 --- /dev/null +++ b/docs/css/api-javadocs.css @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Dynamically injected style for the API docs */ + +.badge { + font-family: Arial, san-serif; + float: right; + margin: 4px; + /* The following declarations are taken from the ScalaDoc template.css */ + display: inline-block; + padding: 2px 4px; + font-size: 11.844px; + font-weight: bold; + line-height: 14px; + color: #ffffff; + text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25); + white-space: nowrap; + vertical-align: baseline; + background-color: #999999; + padding-right: 9px; + padding-left: 9px; + -webkit-border-radius: 9px; + -moz-border-radius: 9px; + border-radius: 9px; +} + +.developer { + background-color: #44751E; +} + +.experimental { + background-color: #257080; +} + +.alphaComponent { + background-color: #bb0000; +} diff --git a/docs/js/api-javadocs.js b/docs/js/api-javadocs.js new file mode 100644 index 0000000000000..ead13d6e5fa7c --- /dev/null +++ b/docs/js/api-javadocs.js @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* Dynamically injected post-processing code for the API docs */ + +$(document).ready(function() { + addBadges(":: AlphaComponent ::", 'Alpha Component'); + addBadges(":: DeveloperApi ::", 'Developer API'); + addBadges(":: Experimental ::", 'Experimental'); +}); + +function addBadges(tag, html) { + var tags = $(".block:contains(" + tag + ")") + + // Remove identifier tags + tags.each(function(index) { + var oldHTML = $(this).html(); + var newHTML = oldHTML.replace(tag, ""); + $(this).html(newHTML); + }); + + // Add html badge tags + tags.each(function(index) { + if ($(this).parent().is('td.colLast')) { + $(this).parent().prepend(html); + } else if ($(this).parent('li.blockList') + .parent('ul.blockList') + .parent('div.description') + .parent().is('div.contentContainer')) { + var contentContainer = $(this).parent('li.blockList') + .parent('ul.blockList') + .parent('div.description') + .parent('div.contentContainer') + var header = contentContainer.prev('div.header'); + if (header.length > 0) { + header.prepend(html); + } else { + contentContainer.prepend(html); + } + } else if ($(this).parent().is('li.blockList')) { + $(this).parent().prepend(html); + } else { + $(this).prepend(html); + } + }); +} diff --git a/docs/ml-features.md b/docs/ml-features.md index f88c0248c1a8a..54068debe2159 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -288,6 +288,94 @@ for words_label in wordsDataFrame.select("words", "label").take(3): +## $n$-gram + +An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. + +`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. + +

+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+ +[`NGram`](api/scala/index.html#org.apache.spark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight scala %} +import org.apache.spark.ml.feature.NGram + +val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) +)).toDF("label", "words") + +val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") +val ngramDataFrame = ngram.transform(wordDataFrame) +ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) +{% endhighlight %} +
+</div>
+
+<div data-lang="java" markdown="1">
+ +[`NGram`](api/java/org/apache/spark/ml/feature/NGram.html) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(0D, Lists.newArrayList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1D, Lists.newArrayList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2D, Lists.newArrayList("Logistic", "regression", "models", "are", "neat")) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); +NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); +DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); +for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); +} +{% endhighlight %} +
+</div>
+
+<div data-lang="python" markdown="1">
+ +[`NGram`](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). + +{% highlight python %} +from pyspark.ml.feature import NGram + +wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) +], ["label", "words"]) +ngram = NGram(inputCol="words", outputCol="ngrams") +ngramDataFrame = ngram.transform(wordDataFrame) +for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) +{% endhighlight %} +
+
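To make the transformation concrete, the following is a small plain-Python sketch (illustrative only, independent of the Spark API) of what `NGram` computes for the first input row above with the default `n=2`: each output entry is a space-delimited string of `n` consecutive words, and an input with fewer than `n` words yields no output.

{% highlight python %}
# Plain-Python illustration of the n-gram transform described above (not part of Spark).
words = ["Hi", "I", "heard", "about", "Spark"]
n = 2  # NGram's default length

ngrams = [" ".join(words[i:i + n]) for i in range(len(words) - n + 1)]
print(ngrams)
# ['Hi I', 'I heard', 'heard about', 'about Spark']
#
# With fewer than n words (e.g. words = ["Hi"]) the range above is empty,
# so no n-grams are produced, matching the behavior documented for NGram.
{% endhighlight %}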
+ + ## Binarizer Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index dcaa3784be874..d72dc20a5ad6e 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -327,11 +327,17 @@ which contains the computed clustering assignments. import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel} import org.apache.spark.mllib.linalg.Vectors -val similarities: RDD[(Long, Long, Double)] = ... +// Load and parse the data +val data = sc.textFile("data/mllib/pic_data.txt") +val similarities = data.map { line => + val parts = line.split(' ') + (parts(0).toLong, parts(1).toLong, parts(2).toDouble) +} +// Cluster the data into two classes using PowerIterationClustering val pic = new PowerIterationClustering() - .setK(3) - .setMaxIterations(20) + .setK(2) + .setMaxIterations(10) val model = pic.run(similarities) model.assignments.foreach { a => @@ -363,11 +369,22 @@ import scala.Tuple2; import scala.Tuple3; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.PowerIterationClustering; import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; -JavaRDD> similarities = ... +// Load and parse the data +JavaRDD data = sc.textFile("data/mllib/pic_data.txt"); +JavaRDD> similarities = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(" "); + return new Tuple3<>(new Long(parts[0]), new Long(parts[1]), new Double(parts[2])); + } + } +); +// Cluster the data into two classes using PowerIterationClustering PowerIterationClustering pic = new PowerIterationClustering() .setK(2) .setMaxIterations(10); @@ -383,6 +400,35 @@ PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc. {% endhighlight %}
+
+
+[`PowerIterationClustering`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering)
+implements the PIC algorithm.
+It takes an `RDD` of `(srcId, dstId, similarity)` tuples representing the
+affinity matrix.
+Calling `PowerIterationClustering.train` returns a
+[`PowerIterationClusteringModel`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClusteringModel),
+which contains the computed clustering assignments.
+
+{% highlight python %}
+from __future__ import print_function
+from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel
+
+# Load and parse the data
+data = sc.textFile("data/mllib/pic_data.txt")
+similarities = data.map(lambda line: tuple([float(x) for x in line.split(' ')]))
+
+# Cluster the data into two classes using PowerIterationClustering
+model = PowerIterationClustering.train(similarities, 2, 10)
+
+model.assignments().foreach(lambda x: print(str(x.id) + " -> " + str(x.cluster)))
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = PowerIterationClusteringModel.load(sc, "myModelPath")
+{% endhighlight %}
+
+
## Latent Dirichlet allocation (LDA) @@ -401,7 +447,7 @@ It supports different inference algorithms via `setOptimizer` function. EMLDAOpt on the likelihood function and yields comprehensive results, while OnlineLDAOptimizer uses iterative mini-batch sampling for [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) and is generally memory friendly. After fitting on the documents, LDA provides: * Topics: Inferred topics, each of which is a probability distribution over terms (words). -* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. (EM only) +* Topic distributions for documents: For each non empty document in the training set, LDA gives a probability distribution over topics. (EM only). Note that for empty documents, we don't create the topic distributions. (EM only) LDA takes the following parameters: diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index dfdf6216b270c..eedc23424ad54 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -77,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -149,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -210,7 +210,7 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index d824dab1d7f7b..3aa040046fca5 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -226,7 +226,8 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") A local matrix has integer-typed row and column indices and double-typed values, stored on a single machine. MLlib supports dense matrices, whose entry values are stored in a single double array in -column major. For example, the following matrix `\[ \begin{pmatrix} +column-major order, and sparse matrices, whose non-zero entry values are stored in the Compressed Sparse +Column (CSC) format in column-major order. For example, the following dense matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ 5.0 & 6.0 @@ -238,28 +239,33 @@ is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the m
The base class of local matrices is -[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one -implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). +[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseMatrix). We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight scala %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +val sm: Matrix = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 2, 1), Array(9, 6, 8)) {% endhighlight %}
The base class of local matrices is -[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one -implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). +[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide two +implementations: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html), +and [`SparseMatrix`](api/java/org/apache/spark/mllib/linalg/SparseMatrix.html). We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight java %} import org.apache.spark.mllib.linalg.Matrix; @@ -267,6 +273,30 @@ import org.apache.spark.mllib.linalg.Matrices; // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +Matrix sm = Matrices.sparse(3, 2, new int[] {0, 1, 3}, new int[] {0, 2, 1}, new double[] {9, 6, 8}); +{% endhighlight %} +
+ +
+
+The base class of local matrices is
+[`Matrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrix), and we provide two
+implementations: [`DenseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.DenseMatrix),
+and [`SparseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseMatrix).
+We recommend using the factory methods implemented
+in [`Matrices`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) to create local
+matrices. Remember, local matrices in MLlib are stored in column-major order.
+
+{% highlight python %}
+from pyspark.mllib.linalg import Matrix, Matrices
+
+# Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0))
+dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6])
+
+# Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0))
+sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8])
+{% endhighlight %}
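The three lists passed to `Matrices.sparse` above are the CSC column pointers, row indices, and non-zero values. As a rough check (a sketch, assuming `SparseMatrix` exposes a `toArray` method as in the Scala API), the layout can be verified by converting the sparse matrix back to a dense array:

{% highlight python %}
from pyspark.mllib.linalg import Matrices

# colPtrs = [0, 1, 3]: column 0 holds values[0:1], column 1 holds values[1:3].
# rowIndices = [0, 2, 1] gives the row of each stored value, so the entries are
# (0,0)=9.0, (2,1)=6.0 and (1,1)=8.0.
sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8])
print(sm.toArray())  # a 3x2 dense array: [[9, 0], [0, 8], [0, 6]]
{% endhighlight %}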
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 83e937635a55b..a69e41e2a1936 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -384,7 +384,7 @@ data2 = labels.zip(normalizer2.transform(features)) [Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. ### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. #### Model Fitting @@ -405,7 +405,7 @@ Note that the user can also construct a `ChiSqSelectorModel` by hand by providin #### Example -The following example shows the basic use of ChiSqSelector. +The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
@@ -419,10 +419,11 @@ import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category val discretizedData = data.map { lp => - LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) } -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features val selector = new ChiSqSelector(50) // Create ChiSqSelector model (selecting features) val transformer = selector.fit(discretizedData) @@ -451,19 +452,20 @@ JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category JavaRDD discretizedData = points.map( new Function() { @Override public LabeledPoint call(LabeledPoint lp) { final double[] discretizedFeatures = new double[lp.features().size()]; for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = lp.features().apply(i) / 16; + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); } return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); } }); -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features ChiSqSelector selector = new ChiSqSelector(50); // Create ChiSqSelector model (selecting features) final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3dc8cc902fa72..3927d65fbf8fb 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -499,7 +499,7 @@ Note that the Python API does not yet support multiclass classification and mode will in the future. {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint from numpy import array @@ -518,6 +518,10 @@ model = LogisticRegressionWithLBFGS.train(parsedData) labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %}
@@ -668,7 +672,7 @@ values. We compute the mean squared error at the end to evaluate Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel from numpy import array # Load and parse the data @@ -686,6 +690,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} @@ -768,6 +776,58 @@ will get better! +
+
+First, we import the necessary classes for parsing our input data and creating the model.
+
+{% highlight python %}
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.regression import StreamingLinearRegressionWithSGD
+{% endhighlight %}
+
+Then we make input streams for training and testing data. We assume a StreamingContext `ssc`
+has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing)
+for more info. For this example, we use labeled points in training and testing streams,
+but in practice you will likely want to use unlabeled vectors for test data.
+
+{% highlight python %}
+def parse(lp):
+    label = float(lp[lp.find('(') + 1: lp.find(',')])
+    vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(','))
+    return LabeledPoint(label, vec)
+
+trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache()
+testData = ssc.textFileStream("/testing/data/dir").map(parse)
+{% endhighlight %}
+
+We create our model by initializing the weights to 0.
+
+{% highlight python %}
+numFeatures = 3
+model = StreamingLinearRegressionWithSGD()
+model.setInitialWeights([0.0] * numFeatures)
+{% endhighlight %}
+
+Now we register the streams for training and testing and start the job.
+
+{% highlight python %}
+model.trainOn(trainingData)
+predictions = model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))
+predictions.pprint()
+
+ssc.start()
+ssc.awaitTermination()
+{% endhighlight %}
+
+We can now save text files with data to the training or testing folders.
+Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label
+and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir`
+the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions.
+As you feed more data to the training directory, the predictions
+will get better!
+
+ diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index bf6d124fd5d8d..e73bd30f3a90a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -119,7 +119,7 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint @@ -140,6 +140,10 @@ model = NaiveBayes.train(training, 1.0) # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 887eae7f4f07b..de5d6485f9b5f 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -283,7 +283,7 @@ approxSample = data.sampleByKey(False, fractions); Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically significant, whether this result occurred by chance or not. MLlib currently supports Pearson's -chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine +chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. @@ -422,6 +422,41 @@ for i, result in enumerate(featureTestResults): +Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +for equality of probability distributions. By providing the name of a theoretical distribution +(currently solely supported for the normal distribution) and its parameters, or a function to +calculate the cumulative distribution according to a given theoretical distribution, the user can +test the null hypothesis that their sample is drawn from that distribution. In the case that the +user tests against the normal distribution (`distName="norm"`), but does not provide distribution +parameters, the test initializes to the standard normal distribution and logs an appropriate +message. + +
+
+[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to
+run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run
+and interpret the hypothesis tests.
+
+{% highlight scala %}
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.stat.Statistics
+
+val data: RDD[Double] = ... // an RDD of sample data
+
+// run a KS test for the sample versus a standard normal distribution
+val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1)
+println(testResult) // summary of the test including the p-value, test statistic,
+                    // and null hypothesis
+                    // if our p-value indicates significance, we can reject the null hypothesis
+
+// perform a KS test using a cumulative distribution function of our making
+val myCDF: Double => Double = ...
+val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF)
+{% endhighlight %}
+
+
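A Python tab is not included above; as a hedged sketch, `pyspark.mllib.stat.Statistics` is expected to expose the same parametric form of the test (the user-defined CDF variant shown in the Scala example is not assumed to be available in Python). Assuming a running `SparkContext` `sc`, it could look like:

{% highlight python %}
from pyspark.mllib.stat import Statistics

# A small, made-up sample purely for illustration.
parallelData = sc.parallelize([0.1, 0.15, 0.2, 0.3, 0.25])

# Run a KS test for the sample versus a standard normal distribution N(0, 1).
testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0.0, 1.0)
print(testResult)  # summary of the test including the p-value and test statistic
{% endhighlight %}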
+ + ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5f1d6daeb27f0..1f915d8ea1d73 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,14 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. + # Mesos Docker Support Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` @@ -298,6 +306,20 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.constraints + Attribute based constraints to be matched against when accepting resource offers. + + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
    +
  • Scalar constraints are matched with "less than or equal" semantics, i.e. the value in the constraint must be less than or equal to the value in the resource offer.
  • Range constraints are matched with "contains" semantics, i.e. the value in the constraint must be within the resource offer's value.
  • Set constraints are matched with "subset of" semantics, i.e. the value in the constraint must be a subset of the resource offer's value.
  • Text constraints are matched with "equality" semantics, i.e. the value in the constraint must be exactly equal to the resource offer's value.
  • In case there is no value present as a part of the constraint, any offer with the corresponding attribute will be accepted (without value check).
+ + # Troubleshooting and Debugging diff --git a/docs/sparkr.md b/docs/sparkr.md index 095ea4308cfeb..4385a4eeacd5c 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -68,7 +68,7 @@ you can specify the packages with the `packages` argument.
{% highlight r %} -sc <- sparkR.init(packages="com.databricks:spark-csv_2.11:1.0.3") +sc <- sparkR.init(sparkPackages="com.databricks:spark-csv_2.11:1.0.3") sqlContext <- sparkRSQL.init(sc) {% endhighlight %}
@@ -116,7 +116,7 @@ sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results <- hiveContext.sql("FROM src SELECT key, value") +results <- sql(hiveContext, "FROM src SELECT key, value") # results is now a DataFrame head(results) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2786e3d2cd6bf..5838bc172fe86 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -828,7 +828,7 @@ using this syntax. {% highlight scala %} val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") -df.select("name", "age").write.format("json").save("namesAndAges.json") +df.select("name", "age").write.format("parquet").save("namesAndAges.parquet") {% endhighlight %} @@ -1637,7 +1637,7 @@ sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = sqlContext.sql("FROM src SELECT key, value").collect() +results <- collect(sql(sqlContext, "FROM src SELECT key, value")) {% endhighlight %} @@ -1773,9 +1773,9 @@ the Data Sources API. The following options are supported:
{% highlight scala %} -val jdbcDF = sqlContext.load("jdbc", Map( - "url" -> "jdbc:postgresql:dbserver", - "dbtable" -> "schema.tablename")) +val jdbcDF = sqlContext.read.format("jdbc").options( + Map("url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")).load() {% endhighlight %}
@@ -1788,7 +1788,7 @@ Map options = new HashMap(); options.put("url", "jdbc:postgresql:dbserver"); options.put("dbtable", "schema.tablename"); -DataFrame jdbcDF = sqlContext.load("jdbc", options) +DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% endhighlight %} @@ -1798,7 +1798,7 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql:dbserver', dbtable='schema.tablename').load() {% endhighlight %} diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 8d6e74370918f..de0461010daec 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -58,6 +58,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). +
+	from pyspark.streaming.flume import FlumeUtils
+
+	flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port])
+
+	By default, the Python API will decode the Flume event body as UTF-8 encoded strings. You can specify a custom decoding function to decode the body byte arrays in Flume events into an arbitrary data type.
+	See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils)
+	and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py).
+
Note that the hostname should be the same as the one used by the resource manager in the @@ -135,6 +144,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
+	from pyspark.streaming.flume import FlumeUtils
+
+	addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])]
+	flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses)
+
+	By default, the Python API will decode the Flume event body as UTF-8 encoded strings. You can specify a custom decoding function to decode the body byte arrays in Flume events into an arbitrary data type.
+	See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils).
+
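As a sketch of the custom decoding mentioned above (hedged: the `bodyDecoder` keyword argument is an assumption about the `FlumeUtils` helpers and should be checked against the API docs linked above), a decoder that keeps the raw event body bytes instead of UTF-8 strings might be passed like this:

	from pyspark.streaming.flume import FlumeUtils

	# Hypothetical custom decoder: keep the raw body bytes instead of decoding them as UTF-8.
	def raw_bytes_decoder(body):
	    return bytes(body) if body is not None else None

	flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses,
	                                             bodyDecoder=raw_bytes_decoder)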
See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index b784d59666fec..2f3013b533eb0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts @@ -854,6 +854,8 @@ it with new information. To use this, you will have to do two steps. 1. Define the state update function - Specify with a function how to update the state using the previous state and the new values from an input stream. +In every batch, Spark will apply the state update function for all existing keys, regardless of whether they have new data in a batch or not. If the update function returns `None` then the key-value pair will be eliminated. + Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. We define the update function as: diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 18ccbc0a3edd0..7c83d68e7993e 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -125,7 +125,7 @@ def setup_external_libs(libs): ) with open(tgz_file_path, "wb") as tgz_file: tgz_file.write(download_stream.read()) - with open(tgz_file_path) as tar: + with open(tgz_file_path, "rb") as tar: if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) sys.exit(1) @@ -325,14 +325,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) @@ -791,7 +793,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon'] + 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] if opts.hadoop_major_version == "1": modules = list(filter(lambda x: x != "mapreduce", modules)) @@ -1153,8 
+1155,8 @@ def ssh(host, opts, command): # If this was an ssh failure, provide the user with hints. if e.returncode == 255: raise UsageError( - "Failed to SSH to remote host {0}.\n" + - "Please check that you have provided the correct --identity-file and " + + "Failed to SSH to remote host {0}.\n" + "Please check that you have provided the correct --identity-file and " "--key-pair parameters and try again.".format(host)) else: raise e diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 0000000000000..55afe1b207fe0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 0000000000000..091b64d8c4af4 --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ + spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R new file mode 100644 index 0000000000000..aa2336e300a91 --- /dev/null +++ b/examples/src/main/r/data-manipulation.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns. 
+ +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 4c129dbe2d12d..d812262fd87dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -52,3 +53,4 @@ object BroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 023bb3ee2d108..36832f51d2ad4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ + // scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -140,3 +141,4 @@ object CassandraCQLTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ec689474aecb0..96ef3e198e380 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { sc.stop() } } +// scalastyle:on println /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 1f12034ce0f57..d651fe4d6ee75 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.io.File @@ -136,3 +137,4 @@ object DFSReadWriteTest { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index e757283823fc3..c42df2b8845d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import scala.collection.JavaConversions._ @@ -46,3 +47,4 @@ object DriverSubmissionTest { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 15f6678648b29..fa4a3afeecd19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -53,3 +54,4 @@ object GroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 95c96111c9b1f..244742327a907 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin @@ -62,3 +63,4 @@ object HBaseTest { admin.close() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index ed2b38e2ca6f8..124dc9af6390f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark._ @@ -41,3 +42,4 @@ object HdfsTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3d5259463003d..af5f216f28ba4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -142,3 +143,4 @@ object LocalALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index ac2ea35bbd0e0..9c8aae53cf48d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -73,3 +74,4 @@ object LocalFileLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 04fc0a033014a..e7b28d38bdfc6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -119,3 +120,4 @@ object LocalKMeans { println("Final centers: " + kPoints) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index c3fc74a116c0a..4f6b092a59ca5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -77,3 +78,4 @@ object LocalLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index ee6b3ee34aeb2..3d923625f11b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -33,3 +34,4 @@ object LocalPi { println("Pi is roughly " + 4 * count / 100000.0) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 75c82117cbad2..a80de10f4610a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -83,3 +84,4 @@ object LogQuery { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 2a5c0c0defe13..61ce9db914f9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.rdd.RDD @@ -53,3 +54,4 @@ object MultiBroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 5291ab81f459e..3b0b00fe4dd0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -67,3 +68,4 @@ object SimpleSkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 017d4e1e5ce13..719e2176fed3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -57,3 +58,4 @@ object SkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 30c4261551837..69799b7c2bb30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -144,3 +145,4 @@ object SparkALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9099c2fcc90b3..505ea5a4c7a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -97,3 +98,4 @@ object SparkHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index b514d9123f5e7..c56e1124ad415 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import breeze.linalg.{Vector, DenseVector, squaredDistance} @@ -100,3 +101,4 @@ object SparkKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 1e6b4fb0c7514..d265c227f4ed2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -86,3 +87,4 @@ object SparkLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index bd7894f184c4c..0fd79660dd196 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.SparkContext._ @@ -74,3 +75,4 @@ object SparkPageRank { ctx.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 35b8dd6c29b66..818d4f2b81f82 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -37,3 +38,4 @@ object SparkPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 772cd897f5140..95072071ccddb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.util.Random @@ -70,3 +71,4 @@ object SparkTC { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 4393b99e636b6..cfbdae02212a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -94,3 +95,4 @@ object SparkTachyonHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 7743f7968b100..e46ac655beb58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -46,3 +47,4 @@ object SparkTachyonPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 409721b01c8fd..8dd6c9706e7df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import scala.collection.mutable @@ -151,3 +152,4 @@ object Analytics extends Logging { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index f6f8d9f90c275..da3ffca1a6f2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.graphx /** @@ -42,3 +43,4 @@ object LiveJournalPageRank { Analytics.main(args.patch(0, List("pagerank"), 0)) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 3ec20d594b784..46e52aacd90bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ @@ -128,3 +129,4 @@ object SynthBenchmark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 6c0af20461d3b..14b358d46f6ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -110,3 +111,4 @@ object CrossValidatorExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 54e4073941056..f28671f7869fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -355,3 +356,4 @@ object DecisionTreeExample { println(s" Root mean squared error (RMSE): $RMSE") } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 7b8cc21ed8982..78f31b4ffe56a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -181,3 +182,4 @@ private class MyLogisticRegressionModel( copyValues(new MyLogisticRegressionModel(uid, weights), extra) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 33905277c7341..f4a15f806ea81 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -236,3 +237,4 @@ object GBTExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b54466fd48bc5..b73299fb12d3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -140,3 +141,4 @@ object LinearRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 3cf193f353fbc..7682557127b51 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -157,3 +158,4 @@ object LogisticRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 25f21113bf622..cd411397a4b9d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scopt.OptionParser @@ -178,3 +179,4 @@ object MovieLensALS { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 6927eb8f275cf..bab31f585b0ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} @@ -183,3 +184,4 @@ object OneVsRestExample { (NANO.toSeconds(t1 - t0), result) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 9f7cad68a4594..109178f4137b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -244,3 +245,4 @@ object RandomForestExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index a0561e2573fc9..58d7b67674ff7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -100,3 +101,4 @@ object SimpleParamsExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 1324b066c30c3..960280137cbf9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.beans.BeanInfo @@ -89,3 +90,4 @@ object SimpleTextClassificationPipeline { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a113653810b93..1a4016f76c2ad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -153,3 +154,4 @@ object BinaryClassification { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index e49129c4e7844..026d4ecc6d10a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -91,3 +92,4 @@ object Correlations { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index cb1abbd18fd4d..69988cc1b9334 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -106,3 +107,4 @@ object CosineSimilarity { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 520893b26d595..dc13f82488af7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.io.File @@ -119,3 +120,4 @@ object DatasetExample { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 3381941673db8..57ffe3dd2524f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.language.reflectiveCalls @@ -368,3 +369,4 @@ object DecisionTreeRunner { } // scalastyle:on structural.type } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index f8c71ccabc43b..1fce4ba7efd60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -65,3 +66,4 @@ object DenseGaussianMixture { println() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 14cc5cbb679c5..380d85d60e7b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -107,3 +108,4 @@ object DenseKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 13f24a1e59610..14b930550d554 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -80,3 +81,4 @@ object FPGrowthExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 7416fb5a40848..e16a6bf033574 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -145,3 +146,4 @@ object GradientBoostedTreesRunner { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 31d629f853161..75b0f69cf91aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.text.BreakIterator @@ -302,3 +303,4 @@ private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Se } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 6a456ba7ec07b..8878061a0970b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -134,3 +135,4 @@ object LinearRegression { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 99588b0984ab2..e43a6f2864c73 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.collection.mutable @@ -189,3 +190,4 @@ object MovieLensALS { math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 6e4e2d07f284b..5f839c75dd581 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -97,3 +98,4 @@ object MultivariateSummarizer { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 6d8b806569dfd..0723223954610 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -154,4 +155,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } - +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 924b586e3af99..bee85ba0f9969 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.random.RandomRDDs @@ -58,3 +59,4 @@ object RandomRDDGeneration { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 663c12734af68..6963f43e082c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.util.MLUtils @@ -125,3 +126,4 @@ object SampledRDDs { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5e..f81fc292a3bd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -100,3 +101,4 @@ object SparseNaiveBayes { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 8bb12d2ee9ed2..af03724a8ac62 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.SparkConf @@ -75,3 +76,4 @@ object StreamingKMeansExample { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1a95048bbfe2d..b4a5dca031abd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -69,3 +70,4 @@ object StreamingLinearRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index e1998099c2d78..b42f4cb5f9338 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -71,3 +72,4 @@ object StreamingLogisticRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 3cd9cb743e309..464fbd385ab5d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnyPCA { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 4d6690318615a..65b4bc46f0266 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnySVD { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index b11e32047dc34..2cc56f04e5c1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} @@ -73,3 +74,4 @@ object RDDRelation { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b7ba60ec28155..bf40bd1ef13df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql.hive import com.google.common.io.{ByteStreams, Files} @@ -77,3 +78,4 @@ object HiveFromSpark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 016de4c63d1d2..e9c9907198769 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import scala.collection.mutable.LinkedList @@ -170,3 +171,4 @@ object ActorWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 30269a7ccae97..28e9bf520e568 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.{InputStreamReader, BufferedReader, InputStream} @@ -100,3 +101,4 @@ class CustomReceiver(host: String, port: Int) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index fbe394de4a179..bd78526f8c299 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import kafka.serializer.StringDecoder @@ -70,3 +71,4 @@ object DirectKafkaWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 20e7df7c45b1b..91e52e4eff5a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -66,3 +67,4 @@ object FlumeEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 1cc8c8d5c23b6..2bdbc37e2a289 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -65,3 +66,4 @@ object FlumePollingEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 4b4667fec44e6..1f282d437dc38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -53,3 +54,4 @@ object HdfsWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 60416ee343544..b40d17e9c2fa3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.util.HashMap @@ -101,3 +102,4 @@ object KafkaWordCountProducer { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 813c8554f5193..d772ae309f40d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.eclipse.paho.client.mqttv3._ @@ -96,8 +97,10 @@ object MQTTWordCount { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println System.err.println( "Usage: MQTTWordCount ") + // scalastyle:on println System.exit(1) } @@ -113,3 +116,4 @@ object MQTTWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 2cd8073dada14..9a57fe286d1ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -57,3 +58,4 @@ object NetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a9aaa445bccb6..5322929d177b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -58,3 +59,4 @@ object RawNetworkGrep { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 751b30ea15782..9916882e4f94a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.File @@ -108,3 +109,4 @@ object RecoverableNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 5a6b9216a3fbc..ed617754cbf1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -99,3 +100,4 @@ object SQLContextSingleton { instance } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 345d0bc441351..02ba1c2eed0f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -78,3 +79,4 @@ object StatefulNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index c10de84a80ffe..825c671a929b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird._ @@ -113,3 +114,4 @@ object TwitterAlgebirdCMS { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 62db5e663b8af..49826ede70418 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird.HyperLogLogMonoid @@ -90,3 +91,4 @@ object TwitterAlgebirdHLL { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f253d75b279f7..49cee1b43c2dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -82,3 +83,4 @@ object TwitterPopularTags { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index e99d1baa72b9f..6ac9a72c37941 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import akka.actor.ActorSystem @@ -97,3 +98,4 @@ object ZeroMQWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 889f052c70263..bea7a47cb2855 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import java.net.ServerSocket @@ -108,3 +109,4 @@ object PageViewGenerator { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index fbacaee98690f..ec7d39da8b2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ @@ -107,3 +108,4 @@ object PageViewStream { ssc.start() } } +// scalastyle:on println diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 0000000000000..13189595d1d6c --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,158 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + provided + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-codec + commons-codec + provided + + + commons-net + commons-net + provided + + + com.google.protobuf + protobuf-java + provided + + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + + + org.scala-lang + scala-library + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + + + flume-provided + + provided + + + + + diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 0000000000000..9d9c3b189415f --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.toList) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 44dec45c227ca..095bfb0c73a9a 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,10 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} +import 
scala.collection.JavaConversions._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -236,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.length == ports.length) + val addresses = hosts.zip(ports).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 0000000000000..91d63d49dbec3 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): JList[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. 
+ val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + Map[String,
String](s"test-$t" -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d772b9ca9b570..d5f9a0aa38f9f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,47 +18,33 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder -import org.scalatest.concurrent.Eventually._ - +import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -89,146 +75,55 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log } private def testFlumePolling(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - writeAndVerify(Seq(sink), Seq(channel)) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + try { + val port = utils.startSingleSink() + + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. 
- val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() try { - writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() + utils.close() } } - def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. 
- eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 - } - } - assert(counter === totalEventsPerChannel * channels.size) - } - ssc.stop() - } - - def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach + try { + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. 
+ eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.map { + case kv => (kv._1.toString, kv._2.toString) + }).map(mapAsJavaMap) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index c926359987d89..5bc4cdf65306c 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,20 +17,12 @@ package org.apache.spark.streaming.flume -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -41,22 +33,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.util.Utils class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,19 +49,29 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + output 
should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -98,58 +88,6 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) - output should be (input) - } - } - /** Class to create socket channel with compression */ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 8059c443827ef..977514fa5a1ec 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -58,6 +58,7 @@ maven-shade-plugin false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kafka-assembly-${project.version}.jar *:* diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 876456c964770..48a1933d92f85 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka import scala.annotation.tailrec import scala.collection.mutable -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -119,8 +119,23 @@ class DirectKafkaInputDStream[ val rdd = KafkaRDD[K, V, U, T, R]( 
context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) - // Report the record number of this batch interval to InputInfoTracker. - val inputInfo = InputInfo(id, rdd.count) + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. + offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 3e6b937af57b0..8465432c5850f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -410,7 +410,7 @@ object KafkaCluster { } Seq("zookeeper.connect", "group.id").foreach { s => - if (!props.contains(s)) { + if (!props.containsKey(s)) { props.setProperty(s, "") } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0e33362d34acd..f3b01bd60b178 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -670,4 +670,17 @@ private class KafkaUtilsPythonHelper { TopicAndPartition(topic, partition) def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 2675042666304..f326e7f1f6f8d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -75,7 +75,7 @@ final class OffsetRange private( } override def toString(): String = { - s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" } /** this is to avoid ClassNotFoundException during checkpoint restore */ diff --git 
a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 8e1715f6dbb95..5b3c79444aa68 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -111,7 +111,7 @@ class DirectKafkaStreamSuite rdd }.foreachRDD { rdd => for (o <- offsetRanges) { - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index f138251748c9e..3636a9037d43f 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version}
+ + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index be8b62d3cc6ba..de749626ec09c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.nio.ByteBuffer @@ -272,3 +273,4 @@ private[streaming] object StreamingExamples extends Logging { } } } +// scalastyle:on println diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index be6b9047d932d..5c07b415cd796 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -66,7 +66,6 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 9591c4e9b8f4e..989e226305265 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl /** A collection of graph generating functions. 
*/ -object GraphGenerators { +object GraphGenerators extends Logging { val RMATa = 0.45 val RMATb = 0.15 @@ -142,7 +142,7 @@ object GraphGenerators { var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { - println(edges.size + " edges") + logDebug(edges.size + " edges") } edges += addEdge(numVertices) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 186d0cc2a977b..61e44dcab578c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +// scalastyle:off println class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass @@ -102,6 +103,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { private val c = {e: TestClass => println(e.baz)} } +// scalastyle:on println object BytecodeUtilsSuite { class TestClass(val foo: Int, val bar: Long) { diff --git a/launcher/pom.xml b/launcher/pom.xml index a853e67f5cf78..2fd768d8119c4 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 33d65d13f0d25..5e793a5c48775 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -136,7 +136,7 @@ void addPermGenSizeOpt(List cmd) { } } - cmd.add("-XX:MaxPermSize=128m"); + cmd.add("-XX:MaxPermSize=256m"); } void addOptionString(List cmd, String options) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 2665a700fe1f5..a16c0d2b5ca0b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -27,7 +27,7 @@ */ class CommandBuilderUtils { - static final String DEFAULT_MEM = "512m"; + static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index d4cfeacb6ef18..c0f89c9230692 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -25,11 +25,12 @@ import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. - *

+ * <p>
  * Use this class to start Spark applications programmatically. The class uses a builder pattern
  * to allow clients to configure the Spark application and launch it as a child process.
+ * </p>
  */
 public class SparkLauncher {
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 3e5a2820b6c11..87c43aa9980e1 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -208,7 +208,7 @@ private List buildSparkSubmitCommand(Map env) throws IOE
     // - properties file.
     // - SPARK_DRIVER_MEMORY env variable
     // - SPARK_MEM env variable
-    // - default value (512m)
+    // - default value (1g)
     // Take Thrift Server as daemon
     String tsMemory =
       isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null;
diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
index 7ed756f4b8591..7c97dba511b28 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
@@ -17,13 +17,17 @@
 /**
  * Library for launching Spark applications.
- * <p/>
+ *
+ * <p>
  * This library allows applications to launch Spark programmatically. There's only one entry
  * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class.
- * <p/>
+ * </p>
+ *
+ * <p>
  * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher}
  * and configure the application to run. For example:
- *
+ * </p>
+ *
  * <pre>
  * {@code
  *   import org.apache.spark.launcher.SparkLauncher;
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
index 97043a76cc612..7329ac9f7fb8c 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
@@ -194,7 +194,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception {
         if (isDriver) {
           assertEquals("-XX:MaxPermSize=256m", arg);
         } else {
-          assertEquals("-XX:MaxPermSize=128m", arg);
+          assertEquals("-XX:MaxPermSize=256m", arg);
         }
       }
     }
diff --git a/make-distribution.sh b/make-distribution.sh
index 9f063da3a16c0..cac7032bb2e87 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -219,6 +219,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR"
 if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then
   mkdir -p "$DISTDIR"/R/lib
   cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib
+  cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib
 fi
 
 # Download and copy in tachyon, if requested
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a1f3851d804ff..aef2c019d2871 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -95,6 +95,8 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
   /** @group setParam */
   def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
 
+  // Below, we clone stages so that modifications to the list of stages will not change
+  // the Param value in the Pipeline.
   /** @group getParam */
   def getStages: Array[PipelineStage] = $(stages).clone()
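A minimal sketch, not part of the patch, of what the defensive copy above buys; `Tokenizer` and `HashingTF` are just stand-in stages:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// getStages returns a clone, so callers cannot mutate the `stages` Param in place.
val pipeline = new Pipeline().setStages(Array(new Tokenizer(), new HashingTF()))
val returned = pipeline.getStages
returned(0) = null                        // only the returned clone is modified
assert(pipeline.getStages(0) != null)     // the Param value inside the Pipeline is untouched
```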
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 2e6eedd45ab07..8fc9199fb4602 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
 
 import scala.collection.mutable
 
-import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
+import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
 
 import org.apache.spark.{Logging, SparkException}
@@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasThreshold
+  with HasThreshold with HasStandardization
 
 /**
  * :: Experimental ::
@@ -98,6 +98,18 @@ class LogisticRegression(override val uid: String)
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
   setDefault(fitIntercept -> true)
 
+  /**
+   * Whether to standardize the training features before fitting the model.
+   * The coefficients of the models will always be returned on the original scale,
+   * so it will be transparent for users. Note that with no regularization,
+   * the models should always converge to the same solution,
+   * with or without standardization.
+   * Default is true.
+   * @group setParam
+   */
+  def setStandardization(value: Boolean): this.type = set(standardization, value)
+  setDefault(standardization -> true)
+
   /** @group setParam */
   def setThreshold(value: Double): this.type = set(threshold, value)
   setDefault(threshold -> 0.5)
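A usage sketch of the new param in the existing builder chain; `training` is an assumed DataFrame with the usual "label" and "features" columns:

```scala
import org.apache.spark.ml.classification.LogisticRegression

// `training` is a hypothetical DataFrame with "label" and "features" columns.
val lr = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.1)
  .setElasticNetParam(0.0)      // L2-only penalty
  .setStandardization(false)    // penalize weights on the original feature scale
val model = lr.fit(training)
```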
@@ -116,7 +128,7 @@ class LogisticRegression(override val uid: String)
           case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
           (label: Double, features: Vector)) =>
             (summarizer.add(features), labelSummarizer.add(label))
-      },
+        },
         combOp = (c1, c2) => (c1, c2) match {
           case ((summarizer1: MultivariateOnlineSummarizer,
           classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
@@ -149,15 +161,28 @@ class LogisticRegression(override val uid: String)
     val regParamL1 = $(elasticNetParam) * $(regParam)
     val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
 
-    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
       featuresStd, featuresMean, regParamL2)
 
     val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
     } else {
-      // Remove the L1 penalization on the intercept
       def regParamL1Fun = (index: Int) => {
-        if (index == numFeatures) 0.0 else regParamL1
+        // Remove the L1 penalization on the intercept
+        if (index == numFeatures) {
+          0.0
+        } else {
+          if ($(standardization)) {
+            regParamL1
+          } else {
+            // If `standardization` is false, we still standardize the data
+            // to improve the rate of convergence; as a result, we have to
+            // perform this reverse standardization by penalizing each component
+            // differently to get effectively the same objective function when
+            // the training dataset is not standardized.
+            if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+          }
+        }
       }
       new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
     }
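The per-index penalty above can be read in isolation; a standalone sketch of the same rule, with `featuresStd` and `numFeatures` as assumed inputs:

```scala
// Effective L1 penalty for the coefficient at `index` when training in the scaled space.
// The intercept (stored at position numFeatures) is never penalized; without
// standardization the penalty is divided by the feature's standard deviation.
def effectiveL1Penalty(
    index: Int,
    regParamL1: Double,
    numFeatures: Int,
    featuresStd: Array[Double],
    standardization: Boolean): Double = {
  if (index == numFeatures) {
    0.0
  } else if (standardization) {
    regParamL1
  } else if (featuresStd(index) != 0.0) {
    regParamL1 / featuresStd(index)
  } else {
    0.0
  }
}
```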
@@ -166,18 +191,18 @@ class LogisticRegression(override val uid: String)
       Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
 
     if ($(fitIntercept)) {
-      /**
-       * For binary logistic regression, when we initialize the weights as zeros,
-       * it will converge faster if we initialize the intercept such that
-       * it follows the distribution of the labels.
-       *
-       * {{{
-       * P(0) = 1 / (1 + \exp(b)), and
-       * P(1) = \exp(b) / (1 + \exp(b))
-       * }}}, hence
-       * {{{
-       * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
-       * }}}
+      /*
+         For binary logistic regression, when we initialize the weights as zeros,
+         it will converge faster if we initialize the intercept such that
+         it follows the distribution of the labels.
+
+         {{{
+         P(0) = 1 / (1 + \exp(b)), and
+         P(1) = \exp(b) / (1 + \exp(b))
+         }}}, hence
+         {{{
+         b = \log{P(1) / P(0)} = \log{count_1 / count_0}
+         }}}
        */
       initialWeightsWithIntercept.toArray(numFeatures)
         = math.log(histogram(1).toDouble / histogram(0).toDouble)
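A quick numeric check of the intercept initialization above, using hypothetical label counts:

```scala
// With 70 negative and 30 positive labels, b = log(count_1 / count_0) makes the
// initial model predict the empirical class prior.
val histogram = Array(70.0, 30.0)
val b = math.log(histogram(1) / histogram(0))   // ≈ -0.847
val p1 = math.exp(b) / (1.0 + math.exp(b))      // = 0.3 = 30 / (70 + 30)
```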
@@ -186,39 +211,48 @@ class LogisticRegression(override val uid: String)
     val states = optimizer.iterations(new CachedDiffFunction(costFun),
       initialWeightsWithIntercept.toBreeze.toDenseVector)
 
-    var state = states.next()
-    val lossHistory = mutable.ArrayBuilder.make[Double]
+    val (weights, intercept, objectiveHistory) = {
+      /*
+         Note that in Logistic Regression, the objective history (loss + regularization)
+         is the log-likelihood, which is invariant under feature standardization. As a result,
+         the objective history from the optimizer is the same as the one in the original space.
+       */
+      val arrayBuilder = mutable.ArrayBuilder.make[Double]
+      var state: optimizer.State = null
+      while (states.hasNext) {
+        state = states.next()
+        arrayBuilder += state.adjustedValue
+      }
 
-    while (states.hasNext) {
-      lossHistory += state.value
-      state = states.next()
-    }
-    lossHistory += state.value
+      if (state == null) {
+        val msg = s"${optimizer.getClass.getName} failed."
+        logError(msg)
+        throw new SparkException(msg)
+      }
 
-    // The weights are trained in the scaled space; we're converting them back to
-    // the original space.
-    val weightsWithIntercept = {
+      /*
+         The weights are trained in the scaled space; we're converting them back to
+         the original space.
+         Note that the intercept in scaled space and original space is the same;
+         as a result, no scaling is needed.
+       */
       val rawWeights = state.x.toArray.clone()
       var i = 0
-      // Note that the intercept in scaled space and original space is the same;
-      // as a result, no scaling is needed.
       while (i < numFeatures) {
         rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
         i += 1
       }
-      Vectors.dense(rawWeights)
+
+      if ($(fitIntercept)) {
+        (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result())
+      } else {
+        (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result())
+      }
     }
 
     if (handlePersistence) instances.unpersist()
 
-    val (weights, intercept) = if ($(fitIntercept)) {
-      (Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)),
-        weightsWithIntercept(weightsWithIntercept.size - 1))
-    } else {
-      (weightsWithIntercept, 0.0)
-    }
-
-    new LogisticRegressionModel(uid, weights.compressed, intercept)
+    copyValues(new LogisticRegressionModel(uid, weights, intercept))
   }
 
   override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
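The back-transformation in the hunk above amounts to dividing each feature weight by its standard deviation; a standalone sketch (intercept excluded, `featuresStd` assumed):

```scala
// Convert feature weights trained on standardized data back to the original scale.
// Features with zero variance get a zero weight, matching the loop above.
def unscaleWeights(scaled: Array[Double], featuresStd: Array[Double]): Array[Double] =
  scaled.zip(featuresStd).map { case (w, std) => if (std != 0.0) w / std else 0.0 }
```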
@@ -423,16 +457,12 @@ private class LogisticAggregator(
     require(dim == data.size, s"Dimensions mismatch when adding new sample." +
       s" Expecting $dim but got ${data.size}.")
 
-    val dataSize = data.size
-
     val localWeightsArray = weightsArray
     val localGradientSumArray = gradientSumArray
 
     numClasses match {
       case 2 =>
-        /**
-         * For Binary Logistic Regression.
-         */
+        // For Binary Logistic Regression.
         val margin = - {
           var sum = 0.0
           data.foreachActive { (index, value) =>
@@ -518,11 +548,13 @@ private class LogisticCostFun(
     data: RDD[(Double, Vector)],
     numClasses: Int,
     fitIntercept: Boolean,
+    standardization: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double],
     regParamL2: Double) extends DiffFunction[BDV[Double]] {
 
   override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+    val numFeatures = featuresStd.length
     val w = Vectors.fromBreeze(weights)
 
     val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
@@ -534,27 +566,43 @@ private class LogisticCostFun(
           case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
         })
 
-    // regVal is the sum of weight squares for L2 regularization
-    val norm = if (regParamL2 == 0.0) {
-      0.0
-    } else if (fitIntercept) {
-      brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
-    } else {
-      brzNorm(weights, 2.0)
-    }
-    val regVal = 0.5 * regParamL2 * norm * norm
-
-    val loss = logisticAggregator.loss + regVal
-    val gradient = logisticAggregator.gradient
+    val totalGradientArray = logisticAggregator.gradient.toArray
 
-    if (fitIntercept) {
-      val wArray = w.toArray.clone()
-      wArray(wArray.length - 1) = 0.0
-      axpy(regParamL2, Vectors.dense(wArray), gradient)
+    // regVal is the sum of weight squares excluding intercept for L2 regularization.
+    val regVal = if (regParamL2 == 0.0) {
+      0.0
     } else {
-      axpy(regParamL2, w, gradient)
+      var sum = 0.0
+      w.foreachActive { (index, value) =>
+        // If `fitIntercept` is true, the last term, which is the intercept, doesn't
+        // contribute to the regularization.
+        if (index != numFeatures) {
+          // The following code computes both the regularization loss and the
+          // regularization gradient, adding the gradient back to totalGradientArray.
+          sum += {
+            if (standardization) {
+              totalGradientArray(index) += regParamL2 * value
+              value * value
+            } else {
+              if (featuresStd(index) != 0.0) {
+                // If `standardization` is false, we still standardize the data
+                // to improve the rate of convergence; as a result, we have to
+                // perform this reverse standardization by penalizing each component
+                // differently to get effectively the same objective function when
+                // the training dataset is not standardized.
+                val temp = value / (featuresStd(index) * featuresStd(index))
+                totalGradientArray(index) += regParamL2 * temp
+                value * temp
+              } else {
+                0.0
+              }
+            }
+          }
+        }
+      }
+      0.5 * regParamL2 * sum
     }
 
-    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+    (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
   }
 }
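The regularization loop above computes, per coefficient, both the L2 loss term and its gradient; a compact sketch of the same arithmetic (intercept case omitted, inputs assumed):

```scala
// Returns (loss contribution, gradient contribution) of the L2 penalty for one weight.
// Without standardization the penalty is rescaled by 1 / std^2, as in the loop above.
def l2Contribution(
    value: Double,
    std: Double,
    regParamL2: Double,
    standardization: Boolean): (Double, Double) = {
  if (standardization) {
    (0.5 * regParamL2 * value * value, regParamL2 * value)
  } else if (std != 0.0) {
    val temp = value / (std * std)
    (0.5 * regParamL2 * value * temp, regParamL2 * temp)
  } else {
    (0.0, 0.0)
  }
}
```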
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
new file mode 100644
index 0000000000000..6b77de89a0330
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
+import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
+
+/**
+ * :: Experimental ::
+ * Converts a text document to a sparse vector of token counts.
+ * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
+ */
+@Experimental
+class CountVectorizerModel (override val uid: String, val vocabulary: Array[String])
+  extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] {
+
+  def this(vocabulary: Array[String]) =
+    this(Identifiable.randomUID("cntVec"), vocabulary)
+
+  /**
+   * Corpus-specific filter to ignore scarce words in a document. For each document, terms with
+   * frequency (count) less than the given threshold are ignored.
+   * Default: 1
+   * @group param
+   */
+  val minTermFreq: IntParam = new IntParam(this, "minTermFreq",
+    "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " +
+      "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1))
+
+  /** @group setParam */
+  def setMinTermFreq(value: Int): this.type = set(minTermFreq, value)
+
+  /** @group getParam */
+  def getMinTermFreq: Int = $(minTermFreq)
+
+  setDefault(minTermFreq -> 1)
+
+  override protected def createTransformFunc: Seq[String] => Vector = {
+    val dict = vocabulary.zipWithIndex.toMap
+    document =>
+      val termCounts = mutable.HashMap.empty[Int, Double]
+      document.foreach { term =>
+        dict.get(term) match {
+          case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
+          case None => // ignore terms not in the vocabulary
+        }
+      }
+      Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq)
+  }
+
+  override protected def validateInputType(inputType: DataType): Unit = {
+    require(inputType.sameType(ArrayType(StringType)),
+      s"Input type must be ArrayType(StringType) but got $inputType.")
+  }
+
+  override protected def outputDataType: DataType = new VectorUDT()
+
+  override def copy(extra: ParamMap): CountVectorizerModel = {
+    val copied = new CountVectorizerModel(uid, vocabulary)
+    copyValues(copied, extra)
+  }
+}
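A usage sketch for the new transformer; `df` is an assumed DataFrame with an array-of-strings column "words":

```scala
import org.apache.spark.ml.feature.CountVectorizerModel

// Count only terms from the fixed vocabulary; terms appearing fewer than
// minTermFreq times in a document are dropped.
val cvm = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
  .setMinTermFreq(1)
val counted = cvm.transform(df)
```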
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
new file mode 100644
index 0000000000000..228347635c92b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import edu.emory.mathcs.jtransforms.dct._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.BooleanParam
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * :: Experimental ::
+ * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero
+ * padding is performed on the input vector.
+ * It returns a real vector of the same length representing the DCT. The return vector is scaled
+ * such that the transform matrix is unitary (aka scaled DCT-II).
+ *
+ * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]].
+ */
+@Experimental
+class DCT(override val uid: String)
+  extends UnaryTransformer[Vector, Vector, DCT] {
+
+  def this() = this(Identifiable.randomUID("dct"))
+
+  /**
+   * Indicates whether to perform the inverse DCT (true) or forward DCT (false).
+   * Default: false
+   * @group param
+   */
+  def inverse: BooleanParam = new BooleanParam(
+    this, "inverse", "Set transformer to perform inverse DCT")
+
+  /** @group setParam */
+  def setInverse(value: Boolean): this.type = set(inverse, value)
+
+  /** @group getParam */
+  def getInverse: Boolean = $(inverse)
+
+  setDefault(inverse -> false)
+
+  override protected def createTransformFunc: Vector => Vector = { vec =>
+    val result = vec.toArray
+    val jTransformer = new DoubleDCT_1D(result.length)
+    if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
+    Vectors.dense(result)
+  }
+
+  override protected def validateInputType(inputType: DataType): Unit = {
+    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
+  }
+
+  override protected def outputDataType: DataType = new VectorUDT
+}
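A usage sketch, assuming a DataFrame `df` with a Vector column "features":

```scala
import org.apache.spark.ml.feature.DCT

// Forward scaled DCT-II of each feature vector; setInverse(true) would invert it.
val dct = new DCT()
  .setInputCol("features")
  .setOutputCol("featuresDCT")
  .setInverse(false)
val transformed = dct.transform(df)
```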
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
new file mode 100644
index 0000000000000..b30adf3df48d2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
+import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StructField, StructType}
+
+/**
+ * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]].
+ */
+private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol {
+
+  /**
+   * lower bound after transformation, shared by all features
+   * Default: 0.0
+   * @group param
+   */
+  val min: DoubleParam = new DoubleParam(this, "min",
+    "lower bound of the output feature range")
+
+  /**
+   * upper bound after transformation, shared by all features
+   * Default: 1.0
+   * @group param
+   */
+  val max: DoubleParam = new DoubleParam(this, "max",
+    "upper bound of the output feature range")
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    val inputType = schema($(inputCol)).dataType
+    require(inputType.isInstanceOf[VectorUDT],
+      s"Input column ${$(inputCol)} must be a vector column")
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
+
+  override def validateParams(): Unit = {
+    require($(min) < $(max), s"The specified min(${$(min)}) is larger than or equal to max(${$(max)})")
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Rescale each feature individually to a common range [min, max] linearly using column summary
+ * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for
+ * feature E is calculated as,
+ *
+ * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min
+ *
+ * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min)
+ * Note that since zero values will probably be transformed to non-zero values, the output of the
+ * transformer will be a DenseVector even for sparse input.
+ */
+@Experimental
+class MinMaxScaler(override val uid: String)
+  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
+
+  def this() = this(Identifiable.randomUID("minMaxScal"))
+
+  setDefault(min -> 0.0, max -> 1.0)
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setMin(value: Double): this.type = set(min, value)
+
+  /** @group setParam */
+  def setMax(value: Double): this.type = set(max, value)
+
+  override def fit(dataset: DataFrame): MinMaxScalerModel = {
+    transformSchema(dataset.schema, logging = true)
+    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
+    val summary = Statistics.colStats(input)
+    copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by [[MinMaxScaler]].
+ *
+ * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529).
+ */
+@Experimental
+class MinMaxScalerModel private[ml] (
+    override val uid: String,
+    val originalMin: Vector,
+    val originalMax: Vector)
+  extends Model[MinMaxScalerModel] with MinMaxScalerParams {
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setMin(value: Double): this.type = set(min, value)
+
+  /** @group setParam */
+  def setMax(value: Double): this.type = set(max, value)
+
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
+    val minArray = originalMin.toArray
+
+    val reScale = udf { (vector: Vector) =>
+      val scale = $(max) - $(min)
+
+      // 0 in sparse vector will probably be rescaled to non-zero
+      val values = vector.toArray
+      val size = values.size
+      var i = 0
+      while (i < size) {
+        val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
+        values(i) = raw * scale + $(min)
+        i += 1
+      }
+      Vectors.dense(values)
+    }
+
+    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): MinMaxScalerModel = {
+    val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
+    copyValues(copied, extra)
+  }
+}
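A usage sketch, assuming a DataFrame `df` with a Vector column "features":

```scala
import org.apache.spark.ml.feature.MinMaxScaler

// Fit the per-feature min/max, then rescale each feature into [0, 1].
val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setMin(0.0)
  .setMax(1.0)
val scalerModel = scaler.fit(df)
val scaled = scalerModel.transform(df)   // output vectors are dense
```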
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
new file mode 100644
index 0000000000000..2d3bb680cf309
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StructField, StructType}
+
+/**
+ * Params for [[PCA]] and [[PCAModel]].
+ */
+private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol {
+
+  /**
+   * The number of principal components.
+   * @group param
+   */
+  final val k: IntParam = new IntParam(this, "k", "the number of principal components")
+
+  /** @group getParam */
+  def getK: Int = $(k)
+
+}
+
+/**
+ * :: Experimental ::
+ * PCA trains a model to project vectors to a low-dimensional space using PCA.
+ */
+@Experimental
+class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
+
+  def this() = this(Identifiable.randomUID("pca"))
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setK(value: Int): this.type = set(k, value)
+
+  /**
+   * Computes a [[PCAModel]] that contains the principal components of the input vectors.
+   */
+  override def fit(dataset: DataFrame): PCAModel = {
+    transformSchema(dataset.schema, logging = true)
+    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
+    val pca = new feature.PCA(k = $(k))
+    val pcaModel = pca.fit(input)
+    copyValues(new PCAModel(uid, pcaModel).setParent(this))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    val inputType = schema($(inputCol)).dataType
+    require(inputType.isInstanceOf[VectorUDT],
+      s"Input column ${$(inputCol)} must be a vector column")
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
+
+  override def copy(extra: ParamMap): PCA = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by [[PCA]].
+ */
+@Experimental
+class PCAModel private[ml] (
+    override val uid: String,
+    pcaModel: feature.PCAModel)
+  extends Model[PCAModel] with PCAParams {
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Transform a vector by computed Principal Components.
+   * NOTE: Vectors to be transformed must be the same length
+   * as the source vectors given to [[PCA.fit()]].
+   */
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    val pcaOp = udf { pcaModel.transform _ }
+    dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    val inputType = schema($(inputCol)).dataType
+    require(inputType.isInstanceOf[VectorUDT],
+      s"Input column ${$(inputCol)} must be a vector column")
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
+
+  override def copy(extra: ParamMap): PCAModel = {
+    val copied = new PCAModel(uid, pcaModel)
+    copyValues(copied, extra)
+  }
+}
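A usage sketch, assuming a DataFrame `df` with a Vector column "features":

```scala
import org.apache.spark.ml.feature.PCA

// Project each feature vector onto the top 3 principal components.
val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)
val pcaModel = pca.fit(df)
val projected = pcaModel.transform(df)
```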
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index ca3c1cfb56b7f..72b545e5db3e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -106,6 +106,12 @@ class StandardScalerModel private[ml] (
     scaler: feature.StandardScalerModel)
   extends Model[StandardScalerModel] with StandardScalerParams {
 
+  /** Standard deviation of the StandardScalerModel */
+  val std: Vector = scaler.std
+
+  /** Mean of the StandardScalerModel */
+  val mean: Vector = scaler.mean
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
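With the two new fields, a fitted model exposes its statistics directly; a hedged sketch where `scalerModel` is an assumed, already-fitted StandardScalerModel:

```scala
import org.apache.spark.mllib.linalg.Vector

// Per-feature statistics computed during fit(), now readable from the model.
val featureStd: Vector = scalerModel.std
val featureMean: Vector = scalerModel.mean
```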
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 50c0d855066f8..d034d7ec6b60e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -341,9 +341,7 @@ trait Params extends Identifiable with Serializable {
    * those are checked during schema validation.
    */
   def validateParams(): Unit = {
-    params.filter(isDefined).foreach { param =>
-      param.asInstanceOf[Param[Any]].validate($(param))
-    }
+    // Do nothing by default.  Override to handle Param interactions.
   }
 
   /**
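Since validateParams is now a no-op by default, only Params with cross-parameter constraints override it; a minimal sketch with a hypothetical trait, mirroring MinMaxScalerParams above:

```scala
import org.apache.spark.ml.param.{DoubleParam, Params}

trait HasBounds extends Params {
  val lower: DoubleParam = new DoubleParam(this, "lower", "lower bound")
  val upper: DoubleParam = new DoubleParam(this, "upper", "upper bound")

  // Each value is fine on its own; only the pair is constrained, so the check
  // lives in validateParams rather than in a per-param validator.
  override def validateParams(): Unit = {
    require($(lower) < $(upper), s"lower (${$(lower)}) must be less than upper (${$(upper)})")
  }
}
```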
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index b0a6af171c01f..f7ae1de522e01 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -54,8 +54,7 @@ private[shared] object SharedParamsCodeGen {
         isValid = "ParamValidators.gtEq(1)"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
       ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
-        " prior to fitting the model sequence. Note that the coefficients of models are" +
-        " always returned on the original scale.", Some("true")),
+        " before fitting the model.", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
       ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
         " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
@@ -135,7 +134,7 @@ private[shared] object SharedParamsCodeGen {
 
     s"""
       |/**
-      | * (private[ml]) Trait for shared param $name$defaultValueDoc.
+      | * Trait for shared param $name$defaultValueDoc.
       | */
       |private[ml] trait Has$Name extends Params {
       |
@@ -174,7 +173,6 @@ private[shared] object SharedParamsCodeGen {
         |package org.apache.spark.ml.param.shared
         |
         |import org.apache.spark.ml.param._
-        |import org.apache.spark.util.Utils
         |
         |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
         |
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index bbe08939b6d75..65e48e4ee5083 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -18,14 +18,13 @@
 package org.apache.spark.ml.param.shared
 
 import org.apache.spark.ml.param._
-import org.apache.spark.util.Utils
 
 // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
 
 // scalastyle:off
 
 /**
- * (private[ml]) Trait for shared param regParam.
+ * Trait for shared param regParam.
  */
 private[ml] trait HasRegParam extends Params {
 
@@ -40,7 +39,7 @@ private[ml] trait HasRegParam extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param maxIter.
+ * Trait for shared param maxIter.
  */
 private[ml] trait HasMaxIter extends Params {
 
@@ -55,7 +54,7 @@ private[ml] trait HasMaxIter extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param featuresCol (default: "features").
+ * Trait for shared param featuresCol (default: "features").
  */
 private[ml] trait HasFeaturesCol extends Params {
 
@@ -72,7 +71,7 @@ private[ml] trait HasFeaturesCol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param labelCol (default: "label").
+ * Trait for shared param labelCol (default: "label").
  */
 private[ml] trait HasLabelCol extends Params {
 
@@ -89,7 +88,7 @@ private[ml] trait HasLabelCol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param predictionCol (default: "prediction").
+ * Trait for shared param predictionCol (default: "prediction").
  */
 private[ml] trait HasPredictionCol extends Params {
 
@@ -106,7 +105,7 @@ private[ml] trait HasPredictionCol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
+ * Trait for shared param rawPredictionCol (default: "rawPrediction").
  */
 private[ml] trait HasRawPredictionCol extends Params {
 
@@ -123,7 +122,7 @@ private[ml] trait HasRawPredictionCol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param probabilityCol (default: "probability").
+ * Trait for shared param probabilityCol (default: "probability").
  */
 private[ml] trait HasProbabilityCol extends Params {
 
@@ -140,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param threshold.
+ * Trait for shared param threshold.
  */
 private[ml] trait HasThreshold extends Params {
 
@@ -155,7 +154,7 @@ private[ml] trait HasThreshold extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param inputCol.
+ * Trait for shared param inputCol.
  */
 private[ml] trait HasInputCol extends Params {
 
@@ -170,7 +169,7 @@ private[ml] trait HasInputCol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param inputCols.
+ * Trait for shared param inputCols.
  */
 private[ml] trait HasInputCols extends Params {
 
@@ -185,7 +184,7 @@ private[ml] trait HasInputCols extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param outputCol (default: uid + "__output").
+ * Trait for shared param outputCol (default: uid + "__output").
  */
 private[ml] trait HasOutputCol extends Params {
 
@@ -202,7 +201,7 @@ private[ml] trait HasOutputCol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param checkpointInterval.
+ * Trait for shared param checkpointInterval.
  */
 private[ml] trait HasCheckpointInterval extends Params {
 
@@ -217,7 +216,7 @@ private[ml] trait HasCheckpointInterval extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param fitIntercept (default: true).
+ * Trait for shared param fitIntercept (default: true).
  */
 private[ml] trait HasFitIntercept extends Params {
 
@@ -234,15 +233,15 @@ private[ml] trait HasFitIntercept extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param standardization (default: true).
+ * Trait for shared param standardization (default: true).
  */
 private[ml] trait HasStandardization extends Params {
 
   /**
-   * Param for whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale..
+   * Param for whether to standardize the training features before fitting the model..
    * @group param
    */
-  final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.")
+  final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.")
 
   setDefault(standardization, true)
 
@@ -251,7 +250,7 @@ private[ml] trait HasStandardization extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
+ * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
  */
 private[ml] trait HasSeed extends Params {
 
@@ -268,7 +267,7 @@ private[ml] trait HasSeed extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param elasticNetParam.
+ * Trait for shared param elasticNetParam.
  */
 private[ml] trait HasElasticNetParam extends Params {
 
@@ -283,7 +282,7 @@ private[ml] trait HasElasticNetParam extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param tol.
+ * Trait for shared param tol.
  */
 private[ml] trait HasTol extends Params {
 
@@ -298,7 +297,7 @@ private[ml] trait HasTol extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param stepSize.
+ * Trait for shared param stepSize.
  */
 private[ml] trait HasStepSize extends Params {
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 036e3acb07412..47c110d027d67 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -172,8 +172,7 @@ final class GBTRegressionModel(
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predict(features))
-    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
-    if (prediction > 0.0) 1.0 else 0.0
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
   }
 
   override def copy(extra: ParamMap): GBTRegressionModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 1b1d7299fb496..8fc986056657d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -22,18 +22,20 @@ import scala.collection.mutable
 import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
 
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
 
@@ -132,7 +134,6 @@ class LinearRegression(override val uid: String)
     val numFeatures = summarizer.mean.size
     val yMean = statCounter.mean
     val yStd = math.sqrt(statCounter.variance)
-    // look at glmnet5.m L761 maaaybe that has info
 
     // If the yStd is zero, then the intercept is yMean with zero weights;
     // as a result, training is not needed.
@@ -140,7 +141,16 @@ class LinearRegression(override val uid: String)
       logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
         s"and the intercept will be the mean of the label; as a result, training is not needed.")
       if (handlePersistence) instances.unpersist()
-      return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
+      val weights = Vectors.sparse(numFeatures, Seq())
+      val intercept = yMean
+
+      val model = new LinearRegressionModel(uid, weights, intercept)
+      val trainingSummary = new LinearRegressionTrainingSummary(
+        model.transform(dataset).select($(predictionCol), $(labelCol)),
+        $(predictionCol),
+        $(labelCol),
+        Array(0D))
+      return copyValues(model.setSummary(trainingSummary))
     }
 
     val featuresMean = summarizer.mean.toArray
@@ -162,21 +172,33 @@ class LinearRegression(override val uid: String)
     }
 
     val initialWeights = Vectors.zeros(numFeatures)
-    val states =
-      optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
-
-    var state = states.next()
-    val lossHistory = mutable.ArrayBuilder.make[Double]
-
-    while (states.hasNext) {
-      lossHistory += state.value
-      state = states.next()
-    }
-    lossHistory += state.value
+    val states = optimizer.iterations(new CachedDiffFunction(costFun),
+      initialWeights.toBreeze.toDenseVector)
+
+    val (weights, objectiveHistory) = {
+      /*
+         Note that in Linear Regression, the objective history (loss + regularization) returned
+         from optimizer is computed in the scaled space given by the following formula.
+         {{{
+         L = 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + regTerms
+         }}}
+       */
+      val arrayBuilder = mutable.ArrayBuilder.make[Double]
+      var state: optimizer.State = null
+      while (states.hasNext) {
+        state = states.next()
+        arrayBuilder += state.adjustedValue
+      }
+      if (state == null) {
+        val msg = s"${optimizer.getClass.getName} failed."
+        logError(msg)
+        throw new SparkException(msg)
+      }
 
-    // The weights are trained in the scaled space; we're converting them back to
-    // the original space.
-    val weights = {
+      /*
+         The weights are trained in the scaled space; we're converting them back to
+         the original space.
+       */
       val rawWeights = state.x.toArray.clone()
       var i = 0
       val len = rawWeights.length
@@ -184,17 +206,26 @@ class LinearRegression(override val uid: String)
         rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
         i += 1
       }
-      Vectors.dense(rawWeights)
+
+      (Vectors.dense(rawWeights).compressed, arrayBuilder.result())
     }
 
-    // The intercept in R's GLMNET is computed using closed form after the coefficients are
-    // converged. See the following discussion for detail.
-    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+    /*
+       The intercept in R's GLMNET is computed using closed form after the coefficients are
+       converged. See the following discussion for detail.
+       http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+     */
     val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
+
     if (handlePersistence) instances.unpersist()
 
-    // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
-    copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
+    val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
+    val trainingSummary = new LinearRegressionTrainingSummary(
+      model.transform(dataset).select($(predictionCol), $(labelCol)),
+      $(predictionCol),
+      $(labelCol),
+      objectiveHistory)
+    model.setSummary(trainingSummary)
   }
 
   override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
@@ -212,13 +243,124 @@ class LinearRegressionModel private[ml] (
   extends RegressionModel[Vector, LinearRegressionModel]
   with LinearRegressionParams {
 
+  private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
+
+  /**
+   * Gets the summary (e.g. residuals, MSE, R-squared) of the model on the training set. An
+   * exception is thrown if `trainingSummary == None`.
+   */
+  def summary: LinearRegressionTrainingSummary = trainingSummary match {
+    case Some(summ) => summ
+    case None =>
+      throw new SparkException(
+        "No training summary available for this LinearRegressionModel",
+        new NullPointerException())
+  }
+
+  private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /** Indicates whether a training summary exists for this model instance. */
+  def hasSummary: Boolean = trainingSummary.isDefined
+
+  /**
+   * Evaluates the model on a test set.
+   * @param dataset Test dataset to evaluate model on.
+   */
+  // TODO: decide on a good name before exposing to public API
+  private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+    val t = udf { features: Vector => predict(features) }
+    val predictionAndObservations = dataset
+      .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol)))
+
+    new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol))
+  }
+
   override protected def predict(features: Vector): Double = {
     dot(features, weights) + intercept
   }
 
   override def copy(extra: ParamMap): LinearRegressionModel = {
-    copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
+    val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept))
+    if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
+    newModel
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Linear regression training results.
+ * @param predictions predictions output by the model's `transform` method.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+@Experimental
+class LinearRegressionTrainingSummary private[regression] (
+    predictions: DataFrame,
+    predictionCol: String,
+    labelCol: String,
+    val objectiveHistory: Array[Double])
+  extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
+
+  /** Number of training iterations until termination */
+  val totalIterations = objectiveHistory.length
+
+}
+
+/**
+ * :: Experimental ::
+ * Linear regression results evaluated on a dataset.
+ * @param predictions predictions output by the model's `transform` method.
+ */
+@Experimental
+class LinearRegressionSummary private[regression] (
+    @transient val predictions: DataFrame,
+    val predictionCol: String,
+    val labelCol: String) extends Serializable {
+
+  @transient private val metrics = new RegressionMetrics(
+    predictions
+      .select(predictionCol, labelCol)
+      .map { case Row(pred: Double, label: Double) => (pred, label) } )
+
+  /**
+   * Returns the explained variance regression score.
+   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+   */
+  val explainedVariance: Double = metrics.explainedVariance
+
+  /**
+   * Returns the mean absolute error, which is a risk function corresponding to the
+   * expected value of the absolute error loss or l1-norm loss.
+   */
+  val meanAbsoluteError: Double = metrics.meanAbsoluteError
+
+  /**
+   * Returns the mean squared error, which is a risk function corresponding to the
+   * expected value of the squared error loss or quadratic loss.
+   */
+  val meanSquaredError: Double = metrics.meanSquaredError
+
+  /**
+   * Returns the root mean squared error, which is defined as the square root of
+   * the mean squared error.
+   */
+  val rootMeanSquaredError: Double = metrics.rootMeanSquaredError
+
+  /**
+   * Returns R^2^, the coefficient of determination.
+   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+   */
+  val r2: Double = metrics.r2
+
+  /** Residuals (predicted value - label value) */
+  @transient lazy val residuals: DataFrame = {
+    val t = udf { (pred: Double, label: Double) => pred - label}
+    predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
   }
+
 }
 
 /**
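As a hedged usage sketch of the new training-summary API (not part of the patch): it assumes an existing `DataFrame` named `training` with the default `label` and `features` columns.

```scala
// Sketch only: fit a LinearRegression and inspect the new LinearRegressionTrainingSummary.
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(50)
  .setRegParam(0.1)
  .setElasticNetParam(0.5)
val model = lr.fit(training)   // `training` is an assumed DataFrame with label/features columns

if (model.hasSummary) {
  val summary = model.summary
  println(s"iterations: ${summary.totalIterations}")
  println(s"objective history: ${summary.objectiveHistory.mkString(", ")}")
  println(s"RMSE: ${summary.rootMeanSquaredError}, r2: ${summary.r2}")
  summary.residuals.show()     // DataFrame with a single "residuals" column
}
```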
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 1929f9d02156e..22873909c33fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.tree
 
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
 
 /**
  * Abstraction for Decision Tree models.
@@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel {
   /** Weights for each tree, zippable with [[trees]] */
   def treeWeights: Array[Double]
 
+  /** Weights used by the python wrappers. */
+  // Note: An array cannot be returned directly due to serialization problems.
+  private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
+
   /** Summary of the model */
   override def toString: String = {
     // Implementing classes should generally override this method to be more descriptive.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 7cd53c6d7ef79..76f651488aef9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -32,10 +32,15 @@ private[spark] object SchemaUtils {
    * @param colName  column name
    * @param dataType  required column data type
    */
-  def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
+  def checkColumnType(
+      schema: StructType,
+      colName: String,
+      dataType: DataType,
+      msg: String = ""): Unit = {
     val actualDataType = schema(colName).dataType
+    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
     require(actualDataType.equals(dataType),
-      s"Column $colName must be of type $dataType but was actually $actualDataType.")
+      s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index a66a404d5c846..e628059c4af8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -28,6 +28,7 @@ import scala.reflect.ClassTag
 
 import net.razorvine.pickle._
 
+import org.apache.spark.SparkContext
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.mllib.classification._
@@ -75,6 +76,15 @@ private[python] class PythonMLLibAPI extends Serializable {
       minPartitions: Int): JavaRDD[LabeledPoint] =
     MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
 
+  /**
+   * Loads and serializes vectors saved with `RDD#saveAsTextFile`.
+   * @param jsc Java SparkContext
+   * @param path file or directory path in any Hadoop-supported file system URI
+   * @return serialized vectors in an RDD
+   */
+  def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] =
+    MLUtils.loadVectors(jsc.sc, path)
+
   private def trainRegressionModel(
       learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
       data: JavaRDD[LabeledPoint],
@@ -632,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable {
     def getVectors: JMap[String, JList[Float]] = {
       model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
     }
+
+    def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index fc509d2ba1470..e459367333d26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -140,6 +140,10 @@ class GaussianMixture private (
     // Get length of the input vectors
     val d = breezeData.first().length
 
+    // Heuristic to distribute the computation of the [[MultivariateGaussian]]s: roughly when
+    // d > 25, except when k is very small.
+    val distributeGaussians = ((k - 1.0) / k) * d > 25
+
     // Determine initial weights and corresponding Gaussians.
     // If the user supplied an initial GMM, we use those values, otherwise
     // we start with uniform weights, a random mean from the data, and
@@ -171,14 +175,25 @@ class GaussianMixture private (
       // Create new distributions based on the partial assignments
       // (often referred to as the "M" step in literature)
       val sumWeights = sums.weights.sum
-      var i = 0
-      while (i < k) {
-        val mu = sums.means(i) / sums.weights(i)
-        BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
-          Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
-        weights(i) = sums.weights(i) / sumWeights
-        gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
-        i = i + 1
+
+      if (distributeGaussians) {
+        val numPartitions = math.min(k, 1024)
+        val tuples =
+          Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i)))
+        val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) =>
+          updateWeightsAndGaussians(mean, sigma, weight, sumWeights)
+        }.collect.unzip
+        Array.copy(ws, 0, weights, 0, ws.length)
+        Array.copy(gs, 0, gaussians, 0, gs.length)
+      } else {
+        var i = 0
+        while (i < k) {
+          val (weight, gaussian) =
+            updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights)
+          weights(i) = weight
+          gaussians(i) = gaussian
+          i = i + 1
+        }
       }
 
       llhp = llh // current becomes previous
@@ -192,6 +207,19 @@ class GaussianMixture private (
   /** Java-friendly version of [[run()]] */
   def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
 
+  private def updateWeightsAndGaussians(
+      mean: BDV[Double],
+      sigma: BreezeMatrix[Double],
+      weight: Double,
+      sumWeights: Double): (Double, MultivariateGaussian) = {
+    val mu = (mean /= weight)
+    BLAS.syr(-weight, Vectors.fromBreeze(mu),
+      Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix])
+    val newWeight = weight / sumWeights
+    val newGaussian = new MultivariateGaussian(mu, sigma / weight)
+    (newWeight, newGaussian)
+  }
+
   /** Average of dense breeze vectors */
   private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
     val v = BDV.zeros[Double](x(0).length)
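To make the heuristic above concrete, a purely illustrative check of the threshold `((k - 1.0) / k) * d > 25` for a few (k, d) pairs (plain Scala, not part of the patch):

```scala
// Illustrative only: the same expression used for `distributeGaussians` above.
def distributeGaussians(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25

Seq((2, 30), (2, 60), (10, 30), (100, 26)).foreach { case (k, d) =>
  println(s"k=$k, d=$d -> distribute=${distributeGaussians(k, d)}")
}
// k=2,   d=30 -> false (0.5  * 30 = 15)
// k=2,   d=60 -> true  (0.5  * 60 = 30)
// k=10,  d=30 -> true  (0.9  * 30 = 27)
// k=100, d=26 -> true  (0.99 * 26 = 25.74)
```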
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
index 4e01e402b4283..2a66263d8b7d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -68,7 +68,7 @@ class PCA(val k: Int) {
  * @param k number of principal components.
  * @param pc a principal components Matrix. Each column is one principal component.
  */
-class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
+class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
   /**
    * Transform a vector by computed Principal Components.
    *
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
new file mode 100644
index 0000000000000..72d0ea0c12e1e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.mllib.fpm.AssociationRules.Rule
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ *
+ * Generates association rules from an [[RDD[FreqItemset[Item]]]]. This method only generates
+ * association rules which have a single item as the consequent.
+ *
+ * @since 1.5.0
+ */
+@Experimental
+class AssociationRules private[fpm] (
+    private var minConfidence: Double) extends Logging with Serializable {
+
+  /**
+   * Constructs a default instance with default parameters {minConfidence = 0.8}.
+   *
+   * @since 1.5.0
+   */
+  def this() = this(0.8)
+
+  /**
+   * Sets the minimal confidence (default: `0.8`).
+   *
+   * @since 1.5.0
+   */
+  def setMinConfidence(minConfidence: Double): this.type = {
+    require(minConfidence >= 0.0 && minConfidence <= 1.0)
+    this.minConfidence = minConfidence
+    this
+  }
+
+  /**
+   * Computes the association rules with confidence above [[minConfidence]].
+   * @param freqItemsets frequent itemsets obtained from [[FPGrowth]]
+   * @return an [[RDD[Rule[Item]]]] containing the association rules.
+   *
+   * @since 1.5.0
+   */
+  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
+    // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
+    val candidates = freqItemsets.flatMap { itemset =>
+      val items = itemset.items
+      items.flatMap { item =>
+        items.partition(_ == item) match {
+          case (consequent, antecedent) if !antecedent.isEmpty =>
+            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
+          case _ => None
+        }
+      }
+    }
+
+    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
+    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
+      .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) =>
+      new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent)
+    }.filter(_.confidence >= minConfidence)
+  }
+
+  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
+    val tag = fakeClassTag[Item]
+    run(freqItemsets.rdd)(tag)
+  }
+}
+
+object AssociationRules {
+
+  /**
+   * :: Experimental ::
+   *
+   * An association rule between sets of items.
+   * @param antecedent hypotheses of the rule
+   * @param consequent conclusion of the rule
+   * @tparam Item item type
+   *
+   * @since 1.5.0
+   */
+  @Experimental
+  class Rule[Item] private[fpm] (
+      val antecedent: Array[Item],
+      val consequent: Array[Item],
+      freqUnion: Double,
+      freqAntecedent: Double) extends Serializable {
+
+    def confidence: Double = freqUnion.toDouble / freqAntecedent
+
+    require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
+      val sharedItems = antecedent.toSet.intersect(consequent.toSet)
+      s"A valid association rule must have disjoint antecedent and " +
+        s"consequent but ${sharedItems} is present in both."
+    })
+  }
+}
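A minimal usage sketch for the new class, assuming an existing `SparkContext` named `sc`; the itemsets and frequencies are invented for illustration:

```scala
// Sketch only: build a tiny RDD of frequent itemsets and mine rules from it.
import org.apache.spark.mllib.fpm.AssociationRules
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

val freqItemsets = sc.parallelize(Seq(
  new FreqItemset(Array("a"), 15L),
  new FreqItemset(Array("b"), 35L),
  new FreqItemset(Array("a", "b"), 12L)
))

val ar = new AssociationRules().setMinConfidence(0.8)
ar.run(freqItemsets).collect().foreach { rule =>
  // Only [a] => [b] survives: confidence = 12/15 = 0.8; [b] => [a] has 12/35 < 0.8.
  println(s"${rule.antecedent.mkString("[", ",", "]")} => " +
    s"${rule.consequent.mkString("[", ",", "]")}, confidence=${rule.confidence}")
}
```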
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index abac08022ea47..e2370a52f4930 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
-import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
+import org.apache.spark.mllib.fpm.FPGrowth._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
@@ -36,11 +36,23 @@ import org.apache.spark.storage.StorageLevel
  * :: Experimental ::
  *
  * Model trained by [[FPGrowth]], which holds frequent itemsets.
- * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]]
+ * @param freqItemsets frequent itemsets, represented as an RDD of [[FreqItemset]]
  * @tparam Item item type
+ *
+ * @since 1.3.0
  */
 @Experimental
-class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
+class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
+  /**
+   * Generates association rules for the [[Item]]s in [[freqItemsets]].
+   * @param confidence minimal confidence of the rules produced
+   * @since 1.5.0
+   */
+  def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
+    val associationRules = new AssociationRules(confidence)
+    associationRules.run(freqItemsets)
+  }
+}
 
 /**
  * :: Experimental ::
@@ -58,21 +70,26 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex
  *
  * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning
  *       (Wikipedia)]]
+ *
+ * @since 1.3.0
  */
 @Experimental
 class FPGrowth private (
     private var minSupport: Double,
-    private var numPartitions: Int,
-    private var ordered: Boolean) extends Logging with Serializable {
+    private var numPartitions: Int) extends Logging with Serializable {
 
   /**
    * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same
-   * as the input data, ordered: `false`}.
+   * as the input data}.
+   *
+   * @since 1.3.0
    */
-  def this() = this(0.3, -1, false)
+  def this() = this(0.3, -1)
 
   /**
    * Sets the minimal support level (default: `0.3`).
+   *
+   * @since 1.3.0
    */
   def setMinSupport(minSupport: Double): this.type = {
     this.minSupport = minSupport
@@ -81,25 +98,20 @@ class FPGrowth private (
 
   /**
    * Sets the number of partitions used by parallel FP-growth (default: same as input data).
+   *
+   * @since 1.3.0
    */
   def setNumPartitions(numPartitions: Int): this.type = {
     this.numPartitions = numPartitions
     this
   }
 
-  /**
-   * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine
-   * itemsets).
-   */
-  def setOrdered(ordered: Boolean): this.type = {
-    this.ordered = ordered
-    this
-  }
-
   /**
    * Computes an FP-Growth model that contains frequent itemsets.
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
+   *
+   * @since 1.3.0
    */
   def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
@@ -165,7 +177,7 @@ class FPGrowth private (
     .flatMap { case (part, tree) =>
       tree.extract(minCount, x => partitioner.getPartition(x) == part)
     }.map { case (ranks, count) =>
-      new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered)
+      new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
     }
   }
 
@@ -181,12 +193,9 @@ class FPGrowth private (
       itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
-    // Filter the basket by frequent items pattern
+    // Filter the basket by the frequent items and sort their ranks.
     val filtered = transaction.flatMap(itemToRank.get)
-    if (!this.ordered) {
-      ju.Arrays.sort(filtered)
-    }
-    // Generate conditional transactions
+    ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
     while (i >= 0) {
@@ -203,6 +212,8 @@ class FPGrowth private (
 
 /**
  * :: Experimental ::
+ *
+ * @since 1.3.0
  */
 @Experimental
 object FPGrowth {
@@ -211,21 +222,16 @@ object FPGrowth {
    * Frequent itemset.
    * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
    * @param freq frequency
-   * @param ordered indicates if items represents an itemset (false) or sequence (true)
    * @tparam Item item type
+   *
+   * @since 1.3.0
    */
-  class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean)
-    extends Serializable {
-
-    /**
-     * Auxillary constructor, assumes unordered by default.
-     */
-    def this(items: Array[Item], freq: Long) {
-      this(items, freq, false)
-    }
+  class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {
 
     /**
      * Returns items in a Java List.
+     *
+     * @since 1.3.0
      */
     def javaItems: java.util.List[Item] = {
       items.toList.asJava
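An end-to-end sketch of the updated FPGrowth API, including the new `generateAssociationRules` call; the transactions and thresholds are made up for illustration and `sc` is an assumed `SparkContext`:

```scala
// Sketch only: mine frequent itemsets from raw transactions, then derive association rules.
import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  Array("r", "z", "h", "k", "p"),
  Array("z", "y", "x", "w", "v", "u", "t", "s"),
  Array("s", "x", "o", "n", "r"),
  Array("x", "z", "y", "m", "t", "s", "q", "e"),
  Array("z"),
  Array("x", "z", "y", "r", "q", "t", "p")
))

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(4)
  .run(transactions)

model.freqItemsets.collect().foreach { itemset =>
  println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}")
}

// New in this patch: rules can be generated directly from the trained model.
model.generateAssociationRules(0.8).collect().foreach { rule =>
  println(s"${rule.antecedent.mkString(",")} => ${rule.consequent.mkString(",")} " +
    s"(confidence ${rule.confidence})")
}
```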
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
new file mode 100644
index 0000000000000..39c48b084e550
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+
+/**
+ *
+ * :: Experimental ::
+ *
+ * Calculate all patterns of a projected database locally.
+ */
+@Experimental
+private[fpm] object LocalPrefixSpan extends Logging with Serializable {
+
+  /**
+   * Calculate all patterns of a projected database.
+   * @param minCount minimum count
+   * @param maxPatternLength maximum pattern length
+   * @param prefix prefix
+   * @param projectedDatabase the projected database
+   * @return a set of sequential pattern pairs,
+   *         the key of each pair is a sequential pattern (a list of items),
+   *         the value of each pair is the pattern's count.
+   */
+  def run(
+      minCount: Long,
+      maxPatternLength: Int,
+      prefix: Array[Int],
+      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
+    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
+    val frequentPatternAndCounts = frequentPrefixAndCounts
+      .map(x => (prefix ++ Array(x._1), x._2))
+    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
+      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
+
+    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
+    if (continueProcess) {
+      val nextPatterns = prefixProjectedDatabases
+        .map(x => run(minCount, maxPatternLength, x._1, x._2))
+        .reduce(_ ++ _)
+      frequentPatternAndCounts ++ nextPatterns
+    } else {
+      frequentPatternAndCounts
+    }
+  }
+
+  /**
+   * Calculate the suffix sequence following a prefix in a sequence.
+   * @param prefix prefix
+   * @param sequence sequence
+   * @return suffix sequence
+   */
+  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
+    val index = sequence.indexOf(prefix)
+    if (index == -1) {
+      Array()
+    } else {
+      sequence.drop(index + 1)
+    }
+  }
+
+  /**
+   * Generates frequent items by filtering the input data using minimal count level.
+   * @param minCount the absolute minimum count
+   * @param sequences sequences data
+   * @return array of item and count pair
+   */
+  private def getFreqItemAndCounts(
+      minCount: Long,
+      sequences: Array[Array[Int]]): Array[(Int, Long)] = {
+    sequences.flatMap(_.distinct)
+      .groupBy(x => x)
+      .mapValues(_.length.toLong)
+      .filter(_._2 >= minCount)
+      .toArray
+  }
+
+  /**
+   * Get the frequent prefixes' projected database.
+   * @param prePrefix the frequent prefixes' prefix
+   * @param frequentPrefixes frequent prefixes
+   * @param sequences sequences data
+   * @return prefixes and projected database
+   */
+  private def getPatternAndProjectedDatabase(
+      prePrefix: Array[Int],
+      frequentPrefixes: Array[Int],
+      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
+    val filteredProjectedDatabase = sequences
+      .map(x => x.filter(frequentPrefixes.contains(_)))
+    frequentPrefixes.map { x =>
+      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
+      (prePrefix ++ Array(x), sub)
+    }.filter(x => x._2.nonEmpty)
+  }
+}
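Since `LocalPrefixSpan` is `private[fpm]`, here is a standalone sketch (not part of the patch) that mirrors the projection behaviour of `getSuffix` described above:

```scala
// Illustrative only: same logic as LocalPrefixSpan.getSuffix.
def suffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
  val index = sequence.indexOf(prefix)
  if (index == -1) Array() else sequence.drop(index + 1)
}

println(suffix(3, Array(1, 3, 2, 3, 4)).mkString(","))  // 2,3,4 (everything after the first 3)
println(suffix(5, Array(1, 3, 2, 3, 4)).mkString(","))  // empty (5 does not occur)
```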
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
new file mode 100644
index 0000000000000..9d8c60ef0fc45
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ *
+ * :: Experimental ::
+ *
+ * A parallel PrefixSpan algorithm to mine sequential patterns.
+ * The PrefixSpan algorithm is described in
+ * [[http://doi.org/10.1109/ICDE.2001.914830]].
+ *
+ * @param minSupport the minimal support level of the sequential pattern; any pattern that
+ *                   appears more than (minSupport * size-of-the-dataset) times will be output
+ * @param maxPatternLength the maximal length of the sequential pattern; any pattern that
+ *                   exceeds this length will not be output
+ *
+ * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
+ *       (Wikipedia)]]
+ */
+@Experimental
+class PrefixSpan private (
+    private var minSupport: Double,
+    private var maxPatternLength: Int) extends Logging with Serializable {
+
+  /**
+   * Constructs a default instance with default parameters
+   * {minSupport: `0.1`, maxPatternLength: `10`}.
+   */
+  def this() = this(0.1, 10)
+
+  /**
+   * Sets the minimal support level (default: `0.1`).
+   */
+  def setMinSupport(minSupport: Double): this.type = {
+    require(minSupport >= 0 && minSupport <= 1,
+      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    this.minSupport = minSupport
+    this
+  }
+
+  /**
+   * Sets maximal pattern length (default: `10`).
+   */
+  def setMaxPatternLength(maxPatternLength: Int): this.type = {
+    require(maxPatternLength >= 1,
+      "The maximum pattern length value must be greater than 0.")
+    this.maxPatternLength = maxPatternLength
+    this
+  }
+
+  /**
+   * Find the complete set of sequential patterns in the input sequences.
+   * @param sequences input data set, which contains a set of sequences;
+   *                  a sequence is an ordered list of elements.
+   * @return a set of sequential pattern pairs,
+   *         the key of each pair is a pattern (a list of elements),
+   *         the value of each pair is the pattern's count.
+   */
+  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    if (sequences.getStorageLevel == StorageLevel.NONE) {
+      logWarning("Input data is not cached.")
+    }
+    val minCount = getMinCount(sequences)
+    val lengthOnePatternsAndCounts =
+      getFreqItemAndCounts(minCount, sequences).collect()
+    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
+      lengthOnePatternsAndCounts.map(_._1), sequences)
+    val groupedProjectedDatabase = prefixAndProjectedDatabase
+      .map(x => (x._1.toSeq, x._2))
+      .groupByKey()
+      .map(x => (x._1.toArray, x._2.toArray))
+    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
+    val lengthOnePatternsAndCountsRdd =
+      sequences.sparkContext.parallelize(
+        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
+    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
+    allPatterns
+  }
+
+  /**
+   * Get the minimum count (sequences count * minSupport).
+   * @param sequences input data set, containing a set of sequences
+   * @return the minimum count
+   */
+  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
+    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+  }
+
+  /**
+   * Generates frequent items by filtering the input data using minimal count level.
+   * @param minCount the absolute minimum count
+   * @param sequences original sequences data
+   * @return array of item and count pair
+   */
+  private def getFreqItemAndCounts(
+      minCount: Long,
+      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
+    sequences.flatMap(_.distinct.map((_, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+  }
+
+  /**
+   * Get the frequent prefixes' projected database.
+   * @param frequentPrefixes frequent prefixes
+   * @param sequences sequences data
+   * @return prefixes and projected database
+   */
+  private def getPrefixAndProjectedDatabase(
+      frequentPrefixes: Array[Int],
+      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
+    val filteredSequences = sequences.map { p =>
+      p.filter (frequentPrefixes.contains(_) )
+    }
+    filteredSequences.flatMap { x =>
+      frequentPrefixes.map { y =>
+        val sub = LocalPrefixSpan.getSuffix(y, x)
+        (Array(y), sub)
+      }.filter(_._2.nonEmpty)
+    }
+  }
+
+  /**
+   * Calculate the patterns locally.
+   * @param minCount the absolute minimum count
+   * @param data patterns and projected sequences data
+   * @return patterns
+   */
+  private def getPatternsInLocal(
+      minCount: Long,
+      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
+    data.flatMap { x =>
+      LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
+    }
+  }
+}
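A hedged usage sketch for the new sequential-pattern miner; the integer-encoded sequences are invented for illustration and `sc` is an assumed `SparkContext`:

```scala
// Sketch only: mine frequent sequential patterns from integer-encoded sequences.
import org.apache.spark.mllib.fpm.PrefixSpan

val sequences = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(1, 3, 4, 5, 5),
  Array(6)
), 2).cache()

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)

prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
  println(s"${pattern.mkString("[", ",", "]")}, $count")
}
```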
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 85e63b1382b5e..0df07663405a3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
 import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
 
 /**
  * Trait for a local matrix.
@@ -114,6 +114,16 @@ sealed trait Matrix extends Serializable {
    *          corresponding value in the matrix with type `Double`.
    */
   private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
+
+  /**
+   * Find the number of non-zero active values.
+   */
+  def numNonzeros: Int
+
+  /**
+   * Find the number of values stored explicitly. These values can be zero as well.
+   */
+  def numActives: Int
 }
 
 @DeveloperApi
@@ -137,7 +147,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
       ))
   }
 
-  override def serialize(obj: Any): Row = {
+  override def serialize(obj: Any): InternalRow = {
     val row = new GenericMutableRow(7)
     obj match {
       case sm: SparseMatrix =>
@@ -163,9 +173,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
 
   override def deserialize(datum: Any): Matrix = {
     datum match {
-      // TODO: something wrong with UDT serialization, should never happen.
-      case m: Matrix => m
-      case row: Row =>
+      case row: InternalRow =>
         require(row.length == 7,
           s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
         val tpe = row.getByte(0)
@@ -193,7 +201,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
     }
   }
 
-  override def hashCode(): Int = 1994
+  // see [SPARK-8647]; this achieves the needed constant hash code without a hard-coded constant.
+  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()
 
   override def typeName: String = "matrix"
 
@@ -323,6 +332,10 @@ class DenseMatrix(
     }
   }
 
+  override def numNonzeros: Int = values.count(_ != 0)
+
+  override def numActives: Int = values.length
+
   /**
    * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed
    * set to false.
@@ -592,6 +605,11 @@ class SparseMatrix(
   def toDense: DenseMatrix = {
     new DenseMatrix(numRows, numCols, toArray)
   }
+
+  override def numNonzeros: Int = values.count(_ != 0)
+
+  override def numActives: Int = values.length
+
 }
 
 /**
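A small illustrative check of the new `numActives`/`numNonzeros` counters (not part of the patch), using a sparse matrix that stores an explicit zero:

```scala
// Illustrative only: numActives counts stored entries, numNonzeros counts stored non-zeros.
import org.apache.spark.mllib.linalg.Matrices

// 3x2 CSC matrix with three stored values, one of which is an explicit 0.0.
val sm = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(1.0, 0.0, 5.0))

println(sm.numActives)   // 3
println(sm.numNonzeros)  // 2
```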
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 2ffa497a99d93..e048b01d92462 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
@@ -175,7 +175,7 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
       StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
   }
 
-  override def serialize(obj: Any): Row = {
+  override def serialize(obj: Any): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
         val row = new GenericMutableRow(4)
@@ -191,17 +191,12 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
         row.setNullAt(2)
         row.update(3, values.toSeq)
         row
-      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
-      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
-      // TODO: deserialize may get called twice. See SPARK-7186.
-      case row: Row =>
-        row
     }
   }
 
   override def deserialize(datum: Any): Vector = {
     datum match {
-      case row: Row =>
+      case row: InternalRow =>
         require(row.length == 4,
           s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
         val tpe = row.getByte(0)
@@ -215,11 +210,6 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
             val values = row.getAs[Iterable[Double]](3).toArray
             new DenseVector(values)
         }
-      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
-      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
-      // TODO: deserialize may get called twice. See SPARK-7186.
-      case v: Vector =>
-        v
     }
   }
 
@@ -234,7 +224,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
     }
   }
 
-  override def hashCode: Int = 7919
+  // see [SPARK-8647]; this achieves the needed constant hash code without a hard-coded constant.
+  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()
 
   override def typeName: String = "vector"
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index 3be530fa07537..1c33b43ea7a8a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -146,7 +146,7 @@ class IndexedRowMatrix(
       val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
         IndexedRow(i, v)
       }
-      new IndexedRowMatrix(indexedRows, nRows, nCols)
+      new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
     } else {
       null
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 06e45e10c5bf4..ab7611fd077ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -19,13 +19,14 @@ package org.apache.spark.mllib.optimization
 
 import scala.collection.mutable.ArrayBuffer
 
-import breeze.linalg.{DenseVector => BDV}
+import breeze.linalg.{DenseVector => BDV, norm}
 
 import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 
+
 /**
  * Class used to solve an optimization problem using Gradient Descent.
  * @param gradient Gradient function to be used.
@@ -38,6 +39,7 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
   private var numIterations: Int = 100
   private var regParam: Double = 0.0
   private var miniBatchFraction: Double = 1.0
+  private var convergenceTol: Double = 0.001
 
   /**
    * Set the initial step size of SGD for the first step. Default 1.0.
@@ -75,6 +77,23 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
     this
   }
 
+  /**
+   * Set the convergence tolerance. Default 0.001.
+   * convergenceTol is a condition that decides when iteration terminates.
+   * Termination is decided based on the following logic:
+   * - If the norm of the new solution vector is >1, the diff of solution vectors
+   *   is compared to the relative tolerance, i.e. normalized by the norm of
+   *   the new solution vector.
+   * - If the norm of the new solution vector is <=1, the diff of solution vectors
+   *   is compared to the absolute tolerance, without normalization.
+   * Must be between 0.0 and 1.0 inclusively.
+   */
+  def setConvergenceTol(tolerance: Double): this.type = {
+    require(0.0 <= tolerance && tolerance <= 1.0)
+    this.convergenceTol = tolerance
+    this
+  }
+
   /**
    * Set the gradient function (of the loss function of one single data example)
    * to be used for SGD.
@@ -112,7 +131,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
       numIterations,
       regParam,
       miniBatchFraction,
-      initialWeights)
+      initialWeights,
+      convergenceTol)
     weights
   }
 
@@ -131,17 +151,20 @@ object GradientDescent extends Logging {
    * Sampling, and averaging the subgradients over this subset is performed using one standard
    * spark map-reduce in each iteration.
    *
-   * @param data - Input data for SGD. RDD of the set of data examples, each of
-   *               the form (label, [feature values]).
-   * @param gradient - Gradient object (used to compute the gradient of the loss function of
-   *                   one single data example)
-   * @param updater - Updater function to actually perform a gradient step in a given direction.
-   * @param stepSize - initial step size for the first step
-   * @param numIterations - number of iterations that SGD should be run.
-   * @param regParam - regularization parameter
-   * @param miniBatchFraction - fraction of the input data set that should be used for
-   *                            one iteration of SGD. Default value 1.0.
-   *
+   * @param data Input data for SGD. RDD of the set of data examples, each of
+   *             the form (label, [feature values]).
+   * @param gradient Gradient object (used to compute the gradient of the loss function of
+   *                 one single data example)
+   * @param updater Updater function to actually perform a gradient step in a given direction.
+   * @param stepSize initial step size for the first step
+   * @param numIterations number of iterations that SGD should be run.
+   * @param regParam regularization parameter
+   * @param miniBatchFraction fraction of the input data set that should be used for
+   *                          one iteration of SGD. Default value 1.0.
+   * @param convergenceTol Minibatch iteration will end before numIterations if the relative
+   *                       difference between the current weight and the previous weight is less
+   *                       than this value. Convergence is measured using the L2 norm.
+   *                       Default value 0.001. Must be between 0.0 and 1.0 inclusively.
    * @return A tuple containing two elements. The first element is a column matrix containing
    *         weights for every feature, and the second element is an array containing the
    *         stochastic loss computed for every iteration.
@@ -154,9 +177,20 @@ object GradientDescent extends Logging {
       numIterations: Int,
       regParam: Double,
       miniBatchFraction: Double,
-      initialWeights: Vector): (Vector, Array[Double]) = {
+      initialWeights: Vector,
+      convergenceTol: Double): (Vector, Array[Double]) = {
+
+    // convergenceTol should be used with non-minibatch settings
+    if (miniBatchFraction < 1.0 && convergenceTol > 0.0) {
+      logWarning("Testing against a convergenceTol when using miniBatchFraction " +
+        "< 1.0 can be unstable because of the stochasticity in sampling.")
+    }
 
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
+    // Record previous weight and current one to calculate solution vector difference
+
+    var previousWeights: Option[Vector] = None
+    var currentWeights: Option[Vector] = None
 
     val numExamples = data.count()
 
@@ -181,7 +215,9 @@ object GradientDescent extends Logging {
     var regVal = updater.compute(
       weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
 
-    for (i <- 1 to numIterations) {
+    var converged = false // indicates whether converged based on convergenceTol
+    var i = 1
+    while (!converged && i <= numIterations) {
       val bcWeights = data.context.broadcast(weights)
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
@@ -204,12 +240,21 @@ object GradientDescent extends Logging {
          */
         stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
         val update = updater.compute(
-          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
+          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
+          stepSize, i, regParam)
         weights = update._1
         regVal = update._2
+
+        previousWeights = currentWeights
+        currentWeights = Some(weights)
+        if (previousWeights != None && currentWeights != None) {
+          converged = isConverged(previousWeights.get,
+            currentWeights.get, convergenceTol)
+        }
       } else {
         logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
       }
+      i += 1
     }
 
     logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
@@ -218,4 +263,32 @@ object GradientDescent extends Logging {
     (weights, stochasticLossHistory.toArray)
 
   }
+
+  def runMiniBatchSGD(
+      data: RDD[(Double, Vector)],
+      gradient: Gradient,
+      updater: Updater,
+      stepSize: Double,
+      numIterations: Int,
+      regParam: Double,
+      miniBatchFraction: Double,
+      initialWeights: Vector): (Vector, Array[Double]) =
+    GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations,
+                                    regParam, miniBatchFraction, initialWeights, 0.001)
+
+
+  private def isConverged(
+      previousWeights: Vector,
+      currentWeights: Vector,
+      convergenceTol: Double): Boolean = {
+    // To compare with convergence tolerance.
+    val previousBDV = previousWeights.toBreeze.toDenseVector
+    val currentBDV = currentWeights.toBreeze.toDenseVector
+
+    // This represents the difference of updated weights in the iteration.
+    val solutionVecDiff: Double = norm(previousBDV - currentBDV)
+
+    solutionVecDiff < convergenceTol * Math.max(norm(currentBDV), 1.0)
+  }
+
 }
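The convergence test added above can be restated as a small standalone sketch (the real implementation works on Breeze vectors); the numbers below are invented:

```scala
// Illustrative only: mirrors isConverged above.
// diff = ||w_new - w_old||_2, converged when diff < tol * max(||w_new||_2, 1.0),
// i.e. a relative tolerance for large weight vectors and an absolute one for small ones.
def converged(previous: Array[Double], current: Array[Double], tol: Double): Boolean = {
  val diff = math.sqrt(previous.zip(current).map { case (p, c) => (p - c) * (p - c) }.sum)
  val currentNorm = math.sqrt(current.map(x => x * x).sum)
  diff < tol * math.max(currentNorm, 1.0)
}

println(converged(Array(1.0, 2.0), Array(1.0005, 2.0005), 0.001))  // true
println(converged(Array(1.0, 2.0), Array(1.1, 2.1), 0.001))        // false
```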
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 93aa41e49961e..43d219a49cf4e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger}
 
 import scala.collection.mutable
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.hadoop.fs.Path
 import org.json4s._
@@ -79,6 +80,30 @@ class MatrixFactorizationModel(
     blas.ddot(rank, userVector, 1, productVector, 1)
   }
 
+  /**
+   * Return approximate numbers of users and products in the given usersProducts tuples.
+   * This method is based on `countApproxDistinct` in class `RDD`.
+   *
+   * @param usersProducts  RDD of (user, product) pairs.
+   * @return approximate numbers of users and products.
+   */
+  private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = {
+    val zeroCounterUser = new HyperLogLogPlus(4, 0)
+    val zeroCounterProduct = new HyperLogLogPlus(4, 0)
+    val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))(
+      (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => {
+        hllTuple._1.offer(v._1)
+        hllTuple._2.offer(v._2)
+        hllTuple
+      },
+      (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => {
+        h1._1.addAll(h2._1)
+        h1._2.addAll(h2._2)
+        h1
+      })
+    (aggregated._1.cardinality(), aggregated._2.cardinality())
+  }
+
   /**
    * Predict the rating of many users for many products.
    * The output RDD has an element per each element in the input RDD (including all duplicates)
@@ -88,12 +113,30 @@ class MatrixFactorizationModel(
    * @return RDD of Ratings.
    */
   def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
-    val users = userFeatures.join(usersProducts).map {
-      case (user, (uFeatures, product)) => (product, (user, uFeatures))
-    }
-    users.join(productFeatures).map {
-      case (product, ((user, uFeatures), pFeatures)) =>
-        Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+    // Previously the partitioning of the ratings was based only on the given products.
+    // So if the usersProducts given for prediction contained only a few products, or
+    // even a single product, the generated ratings would be pushed into a few partitions
+    // (or a single one) and could not use high parallelism.
+    // Here we calculate approximate numbers of users and products, and then decide whether
+    // the partitioning should be based on users or products.
+    val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts)
+
+    if (usersCount < productsCount) {
+      val users = userFeatures.join(usersProducts).map {
+        case (user, (uFeatures, product)) => (product, (user, uFeatures))
+      }
+      users.join(productFeatures).map {
+        case (product, ((user, uFeatures), pFeatures)) =>
+          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+      }
+    } else {
+      val products = productFeatures.join(usersProducts.map(_.swap)).map {
+        case (product, (pFeatures, user)) => (user, (product, pFeatures))
+      }
+      products.join(userFeatures).map {
+        case (user, ((product, pFeatures), uFeatures)) =>
+          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+      }
     }
   }
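For context, a minimal prediction call that exercises this code path; `model` is an assumed, already trained `MatrixFactorizationModel` and the id pairs are invented:

```scala
// Sketch only: batch prediction over (user, product) pairs; the new heuristic decides
// internally whether to join through users or through products.
val usersProducts = sc.parallelize(Seq((1, 10), (1, 11), (2, 10), (3, 12)))
model.predict(usersProducts).collect().foreach { r =>
  println(s"user=${r.user}, product=${r.product}, rating=${r.rating}")
}
```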
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index 235e043c7754b..c6d04464a12ba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -85,4 +85,10 @@ class StreamingLinearRegressionWithSGD private[mllib] (
     this
   }
 
+  /** Set the convergence tolerance. */
+  def setConvergenceTol(tolerance: Double): this.type = {
+    this.algorithm.optimizer.setConvergenceTol(tolerance)
+    this
+  }
+
 }
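
A short usage sketch of the new setter, chained with the class's existing builder-style configuration methods (the feature dimension of 3 is illustrative):

```
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD

val model = new StreamingLinearRegressionWithSGD()
  .setInitialWeights(Vectors.zeros(3))  // 3 features, illustrative
  .setStepSize(0.1)
  .setNumIterations(50)
  .setConvergenceTol(1e-4)              // setter added in this patch
```
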
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index 900007ec6bc74..90332028cfb3a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -17,13 +17,16 @@
 
 package org.apache.spark.mllib.stat
 
+import scala.annotation.varargs
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.distributed.RowMatrix
 import org.apache.spark.mllib.linalg.{Matrix, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.correlation.Correlations
-import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult}
+import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest,
+  KolmogorovSmirnovTestResult}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -158,4 +161,39 @@ object Statistics {
   def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
     ChiSqTest.chiSquaredFeatures(data)
   }
+
+  /**
+   * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
+   * continuous distribution. By computing the largest difference between the empirical cumulative
+   * distribution of the sample data and the theoretical distribution, we can provide a test for
+   * the null hypothesis that the sample data comes from that theoretical distribution.
+   * For more information on KS Test:
+   * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
+   *
+   * @param data an `RDD[Double]` containing the sample of data to test
+   * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value
+   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
+   *        statistic, p-value, and null hypothesis.
+   */
+  def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double)
+    : KolmogorovSmirnovTestResult = {
+    KolmogorovSmirnovTest.testOneSample(data, cdf)
+  }
+
+  /**
+   * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability
+   * distribution equality. Currently supports the normal distribution, taking as parameters
+   * the mean and standard deviation.
+   * (distName = "norm")
+   * @param data an `RDD[Double]` containing the sample of data to test
+   * @param distName a `String` name for a theoretical distribution
+   * @param params `Double*` specifying the parameters to be used for the theoretical distribution
+   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
+   *        statistic, p-value, and null hypothesis.
+   */
+  @varargs
+  def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*)
+    : KolmogorovSmirnovTestResult = {
+    KolmogorovSmirnovTest.testOneSample(data, distName, params: _*)
+  }
 }
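
A usage sketch for the two new entry points (it assumes a `SparkContext` named `sc` is in scope; the sample data is synthetic):

```
import org.apache.commons.math3.distribution.NormalDistribution
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD

// Synthetic sample drawn from a standard normal distribution.
val sample: RDD[Double] = sc.parallelize(Seq.fill(1000)(scala.util.Random.nextGaussian()))

// Variant 1: pass an arbitrary CDF as a Double => Double function.
val stdNormal = new NormalDistribution(0.0, 1.0)
val result1 =
  Statistics.kolmogorovSmirnovTest(sample, (x: Double) => stdNormal.cumulativeProbability(x))

// Variant 2: use the named-distribution overload ("norm" with mean and standard deviation).
val result2 = Statistics.kolmogorovSmirnovTest(sample, "norm", 0.0, 1.0)
println(result2)  // prints the statistic, p-value, and null hypothesis
```
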
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
new file mode 100644
index 0000000000000..d89b0059d83f3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat.test
+
+import scala.annotation.varargs
+
+import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution}
+import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
+
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+
+/**
+ * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a
+ * continuous distribution. By computing the largest difference between the empirical cumulative
+ * distribution of the sample data and the theoretical distribution, we can provide a test for
+ * the null hypothesis that the sample data comes from that theoretical distribution.
+ * For more information on KS Test:
+ * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
+ *
+ * Implementation note: We seek to implement the KS test with a minimal number of distributed
+ * passes. We sort the RDD, and then perform the following operations on a per-partition basis:
+ * calculate an empirical cumulative distribution value for each observation, and a theoretical
+ * cumulative distribution value. We know the latter to be correct, while the former will be off by
+ * a constant (how large the constant is depends on how many values precede it in other partitions).
+ * However, given that this constant simply shifts the empirical CDF upwards, but doesn't
+ * change its shape, and furthermore, that constant is the same within a given partition, we can
+ * pick 2 values in each partition that can potentially resolve to the largest global distance.
+ * Namely, we pick the minimum distance and the maximum distance. Additionally, we keep track of how
+ * many elements are in each partition. Once these three values have been returned for every
+ * partition, we can collect and operate locally. Locally, we can now adjust each distance by the
+ * appropriate constant (the cumulative count of elements in the prior partitions divided by
+ * the data set size). Finally, we take the maximum absolute value, and this is the statistic.
+ */
+private[stat] object KolmogorovSmirnovTest extends Logging {
+
+  // Null hypothesis for the type of KS test to be included in the result.
+  object NullHypothesis extends Enumeration {
+    type NullHypothesis = Value
+    val OneSampleTwoSided = Value("Sample follows theoretical distribution")
+  }
+
+  /**
+   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
+   * @param data `RDD[Double]` data on which to run test
+   * @param cdf `Double => Double` function to calculate the theoretical CDF
+   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
+   *        results (p-value, statistic, and null hypothesis)
+   */
+  def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = {
+    val n = data.count().toDouble
+    val localData = data.sortBy(x => x).mapPartitions { part =>
+      val partDiffs = oneSampleDifferences(part, n, cdf) // local distances
+      searchOneSampleCandidates(partDiffs) // candidates: local extrema
+    }.collect()
+    val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme
+    evalOneSampleP(ksStat, n.toLong)
+  }
+
+  /**
+   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
+   * @param data `RDD[Double]` data on which to run test
+   * @param distObj `RealDistribution` a theoretical distribution
+   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
+   *        results (p-value, statistic, and null hypothesis)
+   */
+  def testOneSample(data: RDD[Double], distObj: RealDistribution): KolmogorovSmirnovTestResult = {
+    val cdf = (x: Double) => distObj.cumulativeProbability(x)
+    testOneSample(data, cdf)
+  }
+
+  /**
+   * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a
+   * partition
+   * @param partData `Iterator[Double]` 1 partition of a sorted RDD
+   * @param n `Double` the total size of the RDD
+   * @param cdf `Double => Double` a function that calculates the theoretical CDF of a value
+   * @return `Iterator[(Double, Double)]` unadjusted (i.e. off by a constant) potential extrema
+   *        in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF,
+   *        the second element corresponds to empirical CDF - CDF.  We can then search the resulting
+   *        iterator for the minimum of the first and the maximum of the second element, and provide
+   *        this as a partition's candidate extrema
+   */
+  private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double)
+    : Iterator[(Double, Double)] = {
+    // zip data with index (within that partition)
+    // calculate local (unadjusted) empirical CDF and subtract CDF
+    partData.zipWithIndex.map { case (v, ix) =>
+      // dp and dl are later adjusted by a constant, once global information is available
+      val dp = (ix + 1) / n
+      val dl = ix / n
+      val cdfVal = cdf(v)
+      (dl - cdfVal, dp - cdfVal)
+    }
+  }
+
+  /**
+   * Search the unadjusted differences in a partition and return the
+   * two extrema (furthest below and furthest above CDF), along with a count of elements in that
+   * partition
+   * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF
+   *                 and CDF in a partition, which come as a tuple of
+   *                 (empirical CDF - 1/N - CDF, empirical CDF - CDF)
+   * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements
+   */
+  private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)])
+    : Iterator[(Double, Double, Double)] = {
+    val initAcc = (Double.MaxValue, Double.MinValue, 0.0)
+    val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) =>
+      (math.min(pMin, dl), math.max(pMax, dp), pCt + 1)
+    }
+    val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults)
+    results.iterator
+  }
+
+  /**
+   * Find the global maximum distance between empirical CDF and CDF (i.e. the KS statistic) after
+   * adjusting local extrema estimates from individual partitions with the amount of elements in
+   * preceding partitions
+   * @param localData `Array[(Double, Double, Double)]` A local array containing the collected
+   *                 results of `searchOneSampleCandidates` across all partitions
+   * @param n `Double` the size of the RDD
+   * @return The one-sample Kolmogorov-Smirnov statistic
+   */
+  private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double)
+    : Double = {
+    val initAcc = (Double.MinValue, 0.0)
+    // adjust differences based on the number of elements preceding it, which should provide
+    // the correct distance between empirical CDF and CDF
+    val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) =>
+      val adjConst = prevCt / n
+      val dist1 = math.abs(minCand + adjConst)
+      val dist2 = math.abs(maxCand + adjConst)
+      val maxVal = Array(prevMax, dist1, dist2).max
+      (maxVal, prevCt + ct)
+    }
+    results._1
+  }
+
+  /**
+   * A convenience function that allows running the KS test for 1 set of sample data against
+   * a named distribution
+   * @param data the sample data that we wish to evaluate
+   * @param distName the name of the theoretical distribution
+   * @param params Variable length parameter for distribution's parameters
+   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the
+   *        test results (p-value, statistic, and null hypothesis)
+   */
+  @varargs
+  def testOneSample(data: RDD[Double], distName: String, params: Double*)
+    : KolmogorovSmirnovTestResult = {
+    val distObj =
+      distName match {
+        case "norm" => {
+          if (params.nonEmpty) {
+            // if parameters are passed, there must be exactly 2
+            require(params.length == 2, "Normal distribution requires mean and standard " +
+              "deviation as parameters")
+            new NormalDistribution(params(0), params(1))
+          } else {
+            // if no parameters are passed, initialize to the standard normal distribution
+            logInfo("No parameters specified for normal distribution, " +
+              "initialized to standard normal (i.e. N(0, 1))")
+            new NormalDistribution(0, 1)
+          }
+        }
+        case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" +
+          " convenience method. Current options are: ['norm'].")
+      }
+
+    testOneSample(data, distObj)
+  }
+
+  private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = {
+    val pval = 1 - new KolmogorovSmirnovTest().cdf(ksStat, n.toInt)
+    new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString)
+  }
+}
+
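
To make the implementation note above concrete, here is a small driver-side sketch of the final merge step, mirroring the fold in `searchOneSampleStatistic` (the per-partition tuples of (min difference, max difference, element count) are made up):

```
// Two hypothetical partitions of a sorted sample with n = 10 elements in total.
val localData = Array((-0.10, 0.05, 4.0), (-0.02, 0.12, 6.0))
val n = 10.0

val ksStat = localData.foldLeft((Double.MinValue, 0.0)) {
  case ((prevMax, prevCt), (minCand, maxCand, ct)) =>
    val adjConst = prevCt / n  // fraction of elements in all preceding partitions
    val maxVal = Array(prevMax, math.abs(minCand + adjConst), math.abs(maxCand + adjConst)).max
    (maxVal, prevCt + ct)
}._1
// Here the second partition is shifted by 4/10 = 0.4, so ksStat = |0.12 + 0.4| = 0.52.
```
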
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
index 4784f9e947908..f44be13706695 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
@@ -90,3 +90,20 @@ class ChiSqTestResult private[stat] (override val pValue: Double,
       super.toString
   }
 }
+
+/**
+ * :: Experimental ::
+ * Object containing the test results for the Kolmogorov-Smirnov test.
+ */
+@Experimental
+class KolmogorovSmirnovTestResult private[stat] (
+    override val pValue: Double,
+    override val statistic: Double,
+    override val nullHypothesis: String) extends TestResult[Int] {
+
+  override val degreesOfFreedom = 0
+
+  override def toString: String = {
+    "Kolmogorov-Smirnov test summary:\n" + super.toString
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 25bb1453db404..f2c78bbabff0b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -198,7 +198,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
         val driverMemory = sc.getConf.getOption("spark.driver.memory")
           .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
           .map(Utils.memoryStringToMb)
-          .getOrElse(512)
+          .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB)
         if (driverMemory <= memThreshold) {
           logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
             s" driver memory (${driverMemory}m)." +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 1e3333d8d81d0..905c5fb42bd44 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -387,7 +387,7 @@ private[tree] object TreeEnsembleModel extends Logging {
         val driverMemory = sc.getConf.getOption("spark.driver.memory")
           .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
           .map(Utils.memoryStringToMb)
-          .getOrElse(512)
+          .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB)
         if (driverMemory <= memThreshold) {
           logWarning(s"$className.save() was called, but it may fail because of too little" +
             s" driver memory (${driverMemory}m)." +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
index 6eaebaf7dba9f..e6bcff48b022c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
@@ -64,8 +64,10 @@ object KMeansDataGenerator {
 
   def main(args: Array[String]) {
     if (args.length < 6) {
+      // scalastyle:off println
       println("Usage: KMeansGenerator " +
         "      []")
+      // scalastyle:on println
       System.exit(1)
     }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index b4e33c98ba7e5..87eeb5db05d26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -153,8 +153,10 @@ object LinearDataGenerator {
 
   def main(args: Array[String]) {
     if (args.length < 2) {
+      // scalastyle:off println
       println("Usage: LinearDataGenerator " +
         "  [num_examples] [num_features] [num_partitions]")
+      // scalastyle:on println
       System.exit(1)
     }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
index 9d802678c4a77..c09cbe69bb971 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
@@ -64,8 +64,10 @@ object LogisticRegressionDataGenerator {
 
   def main(args: Array[String]) {
     if (args.length != 5) {
+      // scalastyle:off println
       println("Usage: LogisticRegressionGenerator " +
         "    ")
+      // scalastyle:on println
       System.exit(1)
     }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
index bd73a866c8a82..16f430599a515 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
@@ -55,8 +55,10 @@ import org.apache.spark.rdd.RDD
 object MFDataGenerator {
   def main(args: Array[String]) {
     if (args.length < 2) {
+      // scalastyle:off println
       println("Usage: MFDataGenerator " +
         "  [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]")
+      // scalastyle:on println
       System.exit(1)
     }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
index a8e30cc9d730c..ad20b7694a779 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
@@ -37,8 +37,10 @@ object SVMDataGenerator {
 
   def main(args: Array[String]) {
     if (args.length < 2) {
+      // scalastyle:off println
       println("Usage: SVMGenerator " +
         "  [num_examples] [num_features] [num_partitions]")
+      // scalastyle:on println
       System.exit(1)
     }
 
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
new file mode 100644
index 0000000000000..845eed61c45c6
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaDCTSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaDCTSuite");
+    jsql = new SQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void javaCompatibilityTest() {
+    double[] input = new double[] {1D, 2D, 3D, 4D};
+    JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
+      RowFactory.create(Vectors.dense(input))
+    ));
+    DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{
+      new StructField("vec", (new VectorUDT()), false, Metadata.empty())
+    }));
+
+    double[] expectedResult = input.clone();
+    (new DoubleDCT_1D(input.length)).forward(expectedResult, true);
+
+    DCT dct = new DCT()
+      .setInputCol("vec")
+      .setOutputCol("resultVec");
+
+    Row[] result = dct.transform(dataset).select("resultVec").collect();
+    Vector resultVec = result[0].getAs("resultVec");
+
+    Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
+  }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
new file mode 100644
index 0000000000000..5cf43fec6f29e
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.distributed.RowMatrix;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaPCASuite implements Serializable {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext sqlContext;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaPCASuite");
+    sqlContext = new SQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  public static class VectorPair implements Serializable {
+    private Vector features = Vectors.dense(0.0);
+    private Vector expected = Vectors.dense(0.0);
+
+    public void setFeatures(Vector features) {
+      this.features = features;
+    }
+
+    public Vector getFeatures() {
+      return this.features;
+    }
+
+    public void setExpected(Vector expected) {
+      this.expected = expected;
+    }
+
+    public Vector getExpected() {
+      return this.expected;
+    }
+  }
+
+  @Test
+  public void testPCA() {
+    List<Vector> points = Lists.newArrayList(
+      Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}),
+      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+    );
+    JavaRDD<Vector> dataRDD = jsc.parallelize(points, 2);
+
+    RowMatrix mat = new RowMatrix(dataRDD.rdd());
+    Matrix pc = mat.computePrincipalComponents(3);
+    JavaRDD<Vector> expected = mat.multiply(pc).rows().toJavaRDD();
+
+    JavaRDD<VectorPair> featuresExpected = dataRDD.zip(expected).map(
+      new Function<Tuple2<Vector, Vector>, VectorPair>() {
+        public VectorPair call(Tuple2<Vector, Vector> pair) {
+          VectorPair featuresExpected = new VectorPair();
+          featuresExpected.setFeatures(pair._1());
+          featuresExpected.setExpected(pair._2());
+          return featuresExpected;
+        }
+      }
+    );
+
+    DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+    PCAModel pca = new PCA()
+      .setInputCol("features")
+      .setOutputCol("pca_features")
+      .setK(3)
+      .fit(df);
+    List<Row> result = pca.transform(df).select("pca_features", "expected").toJavaRDD().collect();
+    for (Row r : result) {
+      Assert.assertEquals(r.get(1), r.get(0));
+    }
+  }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index 581c033f08ebe..b48f190f599a2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -28,12 +28,13 @@
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Matrix;
 import org.apache.spark.mllib.linalg.Vector;
-
+import org.apache.spark.mllib.linalg.Vectors;
 
 public class JavaLDASuite implements Serializable {
   private transient JavaSparkContext sc;
@@ -110,7 +111,15 @@ public void distributedLDAModel() {
 
     // Check: topic distributions
     JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
-    assertEquals(topicDistributions.count(), corpus.count());
+    // SPARK-5562: topicDistributions returns the distribution over topics only for non-empty
+    // docs, so compare its count against nonEmptyCorpus instead of corpus.
+    JavaPairRDD<Long, Vector> nonEmptyCorpus = corpus.filter(
+      new Function<Tuple2<Long, Vector>, Boolean>() {
+        public Boolean call(Tuple2<Long, Vector> tuple2) {
+          return Vectors.norm(tuple2._2(), 1.0) != 0.0;
+        }
+    });
+    assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
   }
 
   @Test
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
new file mode 100644
index 0000000000000..b3815ae6039c0
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm;
+
+import java.io.Serializable;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+
+
+public class JavaAssociationRulesSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaFPGrowth");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void runAssociationRules() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList(
+      new FreqItemset<String>(new String[] {"a"}, 15L),
+      new FreqItemset<String>(new String[] {"b"}, 35L),
+      new FreqItemset<String>(new String[] {"a", "b"}, 18L)
+    ));
+
+    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
+  }
+}
+
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index bd0edf2b9ea62..9ce2c52dca8b6 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -29,7 +29,6 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
 
 public class JavaFPGrowthSuite implements Serializable {
   private transient JavaSparkContext sc;
@@ -62,10 +61,10 @@ public void runFPGrowth() {
       .setNumPartitions(2)
       .run(rdd);
 
-    List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
+    List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
     assertEquals(18, freqItemsets.size());
 
-    for (FreqItemset<String> itemset: freqItemsets) {
+    for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
       // Test return types.
       List<String> items = itemset.javaItems();
       long freq = itemset.freq();
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 5a6265ea992c6..b7dd44753896a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -36,19 +36,19 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
 
-    /**
-     * Here is the instruction describing how to export the test data into CSV format
-     * so we can validate the training accuracy compared with R's glmnet package.
-     *
-     * import org.apache.spark.mllib.classification.LogisticRegressionSuite
-     * val nPoints = 10000
-     * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
-     * val xMean = Array(5.843, 3.057, 3.758, 1.199)
-     * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
-     * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
-     *   weights, xMean, xVariance, true, nPoints, 42), 1)
-     * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
-     *   + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
+    /*
+       Here is the instruction describing how to export the test data into CSV format
+       so we can validate the training accuracy compared with R's glmnet package.
+
+       import org.apache.spark.mllib.classification.LogisticRegressionSuite
+       val nPoints = 10000
+       val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
+       val xMean = Array(5.843, 3.057, 3.758, 1.199)
+       val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+       val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
+         weights, xMean, xVariance, true, nPoints, 42), 1)
+       data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
+         + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
      */
     binaryDataset = {
       val nPoints = 10000
@@ -77,6 +77,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(lr.getRawPredictionCol === "rawPrediction")
     assert(lr.getProbabilityCol === "probability")
     assert(lr.getFitIntercept)
+    assert(lr.getStandardization)
     val model = lr.fit(dataset)
     model.transform(dataset)
       .select("label", "probability", "prediction", "rawPrediction")
@@ -208,267 +209,443 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("binary logistic regression with intercept without regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(true)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                     s0
-     * (Intercept)  2.8366423
-     * data.V2     -0.5895848
-     * data.V3      0.8931147
-     * data.V4     -0.3925051
-     * data.V5     -0.7996864
+    val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                           s0
+       (Intercept)  2.8366423
+       data.V2     -0.5895848
+       data.V3      0.8931147
+       data.V4     -0.3925051
+       data.V5     -0.7996864
      */
     val interceptR = 2.8366423
-    val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
+    val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
-    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
-    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+    assert(model1.intercept ~== interceptR relTol 1E-3)
+    assert(model1.weights ~= weightsR relTol 1E-3)
+
+    // Without regularization, the solution is the same with or without standardization.
+    assert(model2.intercept ~== interceptR relTol 1E-3)
+    assert(model2.weights ~= weightsR relTol 1E-3)
   }
 
   test("binary logistic regression without intercept without regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(false)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights =
-     *     coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                     s0
-     * (Intercept)   .
-     * data.V2     -0.3534996
-     * data.V3      1.2964482
-     * data.V4     -0.3571741
-     * data.V5     -0.7407946
+    val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights =
+           coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                           s0
+       (Intercept)   .
+       data.V2     -0.3534996
+       data.V3      1.2964482
+       data.V4     -0.3571741
+       data.V5     -0.7407946
      */
     val interceptR = 0.0
-    val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
+    val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
+
+    assert(model1.intercept ~== interceptR relTol 1E-3)
+    assert(model1.weights ~= weightsR relTol 1E-2)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
-    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
-    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+    // Without regularization, the solution should be the same with or without standardization.
+    assert(model2.intercept ~== interceptR relTol 1E-3)
+    assert(model2.weights ~= weightsR relTol 1E-2)
   }
 
   test("binary logistic regression with intercept with L1 regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(true)
-      .setElasticNetParam(1.0).setRegParam(0.12)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                      s0
-     * (Intercept) -0.05627428
-     * data.V2       .
-     * data.V3       .
-     * data.V4     -0.04325749
-     * data.V5     -0.02481551
+    val trainer1 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept) -0.05627428
+       data.V2       .
+       data.V3       .
+       data.V4     -0.04325749
+       data.V5     -0.02481551
      */
-    val interceptR = -0.05627428
-    val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551)
-
-    assert(model.intercept ~== interceptR relTol 1E-2)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
-    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
-    assert(model.weights(3) ~== weightsR(3) relTol 2E-2)
+    val interceptR1 = -0.05627428
+    val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551)
+
+    assert(model1.intercept ~== interceptR1 relTol 1E-2)
+    assert(model1.weights ~= weightsR1 absTol 2E-2)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
+           standardize=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                           s0
+       (Intercept)  0.3722152
+       data.V2       .
+       data.V3       .
+       data.V4     -0.1665453
+       data.V5       .
+     */
+    val interceptR2 = 0.3722152
+    val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0)
+
+    assert(model2.intercept ~== interceptR2 relTol 1E-2)
+    assert(model2.weights ~= weightsR2 absTol 1E-3)
   }
 
   test("binary logistic regression without intercept with L1 regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(false)
-      .setElasticNetParam(1.0).setRegParam(0.12)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
-     *     intercept=FALSE))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                      s0
-     * (Intercept)   .
-     * data.V2       .
-     * data.V3       .
-     * data.V4     -0.05189203
-     * data.V5     -0.03891782
+    val trainer1 = (new LogisticRegression).setFitIntercept(false)
+      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(false)
+      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
+           intercept=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)   .
+       data.V2       .
+       data.V3       .
+       data.V4     -0.05189203
+       data.V5     -0.03891782
      */
-    val interceptR = 0.0
-    val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782)
+    val interceptR1 = 0.0
+    val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782)
+
+    assert(model1.intercept ~== interceptR1 relTol 1E-3)
+    assert(model1.weights ~= weightsR1 absTol 1E-3)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
+           intercept=FALSE, standardize=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)   .
+       data.V2       .
+       data.V3       .
+       data.V4     -0.08420782
+       data.V5       .
+     */
+    val interceptR2 = 0.0
+    val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
-    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
-    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
+    assert(model2.intercept ~== interceptR2 absTol 1E-3)
+    assert(model2.weights ~= weightsR2 absTol 1E-3)
   }
 
   test("binary logistic regression with intercept with L2 regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(true)
-      .setElasticNetParam(0.0).setRegParam(1.37)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                      s0
-     * (Intercept)  0.15021751
-     * data.V2     -0.07251837
-     * data.V3      0.10724191
-     * data.V4     -0.04865309
-     * data.V5     -0.10062872
+    val trainer1 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)  0.15021751
+       data.V2     -0.07251837
+       data.V3      0.10724191
+       data.V4     -0.04865309
+       data.V5     -0.10062872
      */
-    val interceptR = 0.15021751
-    val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
-
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
-    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
-    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+    val interceptR1 = 0.15021751
+    val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
+
+    assert(model1.intercept ~== interceptR1 relTol 1E-3)
+    assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
+           standardize=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)  0.48657516
+       data.V2     -0.05155371
+       data.V3      0.02301057
+       data.V4     -0.11482896
+       data.V5     -0.06266838
+     */
+    val interceptR2 = 0.48657516
+    val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838)
+
+    assert(model2.intercept ~== interceptR2 relTol 1E-3)
+    assert(model2.weights ~= weightsR2 relTol 1E-3)
   }
 
   test("binary logistic regression without intercept with L2 regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(false)
-      .setElasticNetParam(0.0).setRegParam(1.37)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
-     *     intercept=FALSE))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                      s0
-     * (Intercept)   .
-     * data.V2     -0.06099165
-     * data.V3      0.12857058
-     * data.V4     -0.04708770
-     * data.V5     -0.09799775
+    val trainer1 = (new LogisticRegression).setFitIntercept(false)
+      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(false)
+      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
+           intercept=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)   .
+       data.V2     -0.06099165
+       data.V3      0.12857058
+       data.V4     -0.04708770
+       data.V5     -0.09799775
      */
-    val interceptR = 0.0
-    val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
+    val interceptR1 = 0.0
+    val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
+
+    assert(model1.intercept ~== interceptR1 absTol 1E-3)
+    assert(model1.weights ~= weightsR1 relTol 1E-2)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
+           intercept=FALSE, standardize=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                             s0
+       (Intercept)   .
+       data.V2     -0.005679651
+       data.V3      0.048967094
+       data.V4     -0.093714016
+       data.V5     -0.053314311
+     */
+    val interceptR2 = 0.0
+    val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
-    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
-    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+    assert(model2.intercept ~== interceptR2 absTol 1E-3)
+    assert(model2.weights ~= weightsR2 relTol 1E-2)
   }
 
   test("binary logistic regression with intercept with ElasticNet regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(true)
-      .setElasticNetParam(0.38).setRegParam(0.21)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                      s0
-     * (Intercept)  0.57734851
-     * data.V2     -0.05310287
-     * data.V3       .
-     * data.V4     -0.08849250
-     * data.V5     -0.15458796
+    val trainer1 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)  0.57734851
+       data.V2     -0.05310287
+       data.V3       .
+       data.V4     -0.08849250
+       data.V5     -0.15458796
      */
-    val interceptR = 0.57734851
-    val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)
-
-    assert(model.intercept ~== interceptR relTol 6E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 5E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
-    assert(model.weights(2) ~== weightsR(2) relTol 5E-3)
-    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
+    val interceptR1 = 0.57734851
+    val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796)
+
+    assert(model1.intercept ~== interceptR1 relTol 6E-3)
+    assert(model1.weights ~== weightsR1 absTol 5E-3)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
+           standardize=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)  0.51555993
+       data.V2       .
+       data.V3       .
+       data.V4     -0.18807395
+       data.V5     -0.05350074
+     */
+    val interceptR2 = 0.51555993
+    val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074)
+
+    assert(model2.intercept ~== interceptR2 relTol 6E-3)
+    assert(model2.weights ~= weightsR2 absTol 1E-3)
   }
 
   test("binary logistic regression without intercept with ElasticNet regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(false)
-      .setElasticNetParam(0.38).setRegParam(0.21)
-    val model = trainer.fit(binaryDataset)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
-     *     intercept=FALSE))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                      s0
-     * (Intercept)   .
-     * data.V2     -0.001005743
-     * data.V3      0.072577857
-     * data.V4     -0.081203769
-     * data.V5     -0.142534158
+    val trainer1 = (new LogisticRegression).setFitIntercept(false)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(false)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
+           intercept=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)   .
+       data.V2     -0.001005743
+       data.V3      0.072577857
+       data.V4     -0.081203769
+       data.V5     -0.142534158
      */
-    val interceptR = 0.0
-    val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
+    val interceptR1 = 0.0
+    val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
+
+    assert(model1.intercept ~== interceptR1 relTol 1E-3)
+    assert(model1.weights ~= weightsR1 absTol 1E-2)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
+           intercept=FALSE, standardize=FALSE))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept)   .
+       data.V2       .
+       data.V3      0.03345223
+       data.V4     -0.11304532
+       data.V5       .
+     */
+    val interceptR2 = 0.0
+    val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) absTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) absTol 1E-2)
-    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
-    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
+    assert(model2.intercept ~== interceptR2 absTol 1E-3)
+    assert(model2.weights ~= weightsR2 absTol 1E-3)
   }
 
   test("binary logistic regression with intercept with strong L1 regularization") {
-    val trainer = (new LogisticRegression).setFitIntercept(true)
-      .setElasticNetParam(1.0).setRegParam(6.0)
-    val model = trainer.fit(binaryDataset)
+    val trainer1 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true)
+    val trainer2 = (new LogisticRegression).setFitIntercept(true)
+      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false)
+
+    val model1 = trainer1.fit(binaryDataset)
+    val model2 = trainer2.fit(binaryDataset)
 
     val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label }
       .treeAggregate(new MultiClassSummarizer)(
@@ -480,50 +657,48 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             classSummarizer1.merge(classSummarizer2)
         }).histogram
 
-    /**
-     * For binary logistic regression with strong L1 regularization, all the weights will be zeros.
-     * As a result,
-     * {{{
-     * P(0) = 1 / (1 + \exp(b)), and
-     * P(1) = \exp(b) / (1 + \exp(b))
-     * }}}, hence
-     * {{{
-     * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
-     * }}}
+    /*
+       For binary logistic regression with strong L1 regularization, all the weights will be zeros.
+       As a result,
+       {{{
+       P(0) = 1 / (1 + \exp(b)), and
+       P(1) = \exp(b) / (1 + \exp(b))
+       }}}, hence
+       {{{
+       b = \log{P(1) / P(0)} = \log{count_1 / count_0}
+       }}}
      */
     val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
-    val weightsTheory = Array(0.0, 0.0, 0.0, 0.0)
-
-    assert(model.intercept ~== interceptTheory relTol 1E-5)
-    assert(model.weights(0) ~== weightsTheory(0) absTol 1E-6)
-    assert(model.weights(1) ~== weightsTheory(1) absTol 1E-6)
-    assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6)
-    assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6)
-
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * > library("glmnet")
-     * > data <- read.csv("path", header=FALSE)
-     * > label = factor(data$V1)
-     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
-     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
-     * > weights
-     * 5 x 1 sparse Matrix of class "dgCMatrix"
-     *                      s0
-     * (Intercept) -0.2480643
-     * data.V2      0.0000000
-     * data.V3       .
-     * data.V4       .
-     * data.V5       .
+    val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0)
+
+    assert(model1.intercept ~== interceptTheory relTol 1E-5)
+    assert(model1.weights ~= weightsTheory absTol 1E-6)
+
+    assert(model2.intercept ~== interceptTheory relTol 1E-5)
+    assert(model2.weights ~= weightsTheory absTol 1E-6)
+
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE)
+       label = factor(data$V1)
+       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+       weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
+       weights
+
+       5 x 1 sparse Matrix of class "dgCMatrix"
+                            s0
+       (Intercept) -0.2480643
+       data.V2      0.0000000
+       data.V3       .
+       data.V4       .
+       data.V5       .
      */
     val interceptR = -0.248065
-    val weightsR = Array(0.0, 0.0, 0.0, 0.0)
+    val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
 
-    assert(model.intercept ~== interceptR relTol 1E-5)
-    assert(model.weights(0) ~== weightsR(0) absTol 1E-6)
-    assert(model.weights(1) ~== weightsR(1) absTol 1E-6)
-    assert(model.weights(2) ~== weightsR(2) absTol 1E-6)
-    assert(model.weights(3) ~== weightsR(3) absTol 1E-6)
+    assert(model1.intercept ~== interceptR relTol 1E-5)
+    assert(model1.weights ~= weightsR absTol 1E-6)
   }
 }
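
Note on the strong-L1 test above: because the penalty drives every weight to zero, the fitted model degenerates to a constant, and the intercept is simply the log-odds of the label counts. A standalone sketch of that arithmetic (the object name and counts are illustrative only, not part of the suite):

    object InterceptTheorySketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical class counts, chosen only for illustration.
        val count0 = 400.0
        val count1 = 600.0

        // With every weight forced to zero, P(1) = exp(b) / (1 + exp(b)),
        // hence b = log(P(1) / P(0)) = log(count1 / count0).
        val intercept = math.log(count1 / count0)

        // Plugging b back into the sigmoid recovers the empirical P(1) = 0.6.
        val p1 = math.exp(intercept) / (1.0 + math.exp(intercept))
        println(s"intercept = $intercept, recovered P(1) = $p1")
      }
    }
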
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
new file mode 100644
index 0000000000000..e90d9d4ef21ff
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
+  }
+
+  test("CountVectorizerModel common cases") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, "a b c d".split(" ").toSeq,
+        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
+      (1, "a b b c d  a".split(" ").toSeq,
+        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
+      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))),
+      (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string
+      (4, "a notInDict d".split(" ").toSeq,
+        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))  // with words not in vocabulary
+    )).toDF("id", "words", "expected")
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+    val output = cv.transform(df).collect()
+    output.foreach { p =>
+      val features = p.getAs[Vector]("features")
+      val expected = p.getAs[Vector]("expected")
+      assert(features ~== expected absTol 1e-14)
+    }
+  }
+
+  test("CountVectorizerModel with minTermFreq") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
+      (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))),
+      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())),
+      (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq())))
+    ).toDF("id", "words", "expected")
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinTermFreq(3)
+    val output = cv.transform(df).collect()
+    output.foreach { p =>
+      val features = p.getAs[Vector]("features")
+      val expected = p.getAs[Vector]("expected")
+      assert(features ~== expected absTol 1e-14)
+    }
+  }
+}
+
+
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
new file mode 100644
index 0000000000000..37ed2367c33f7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+@BeanInfo
+case class DCTTestData(vec: Vector, wantedVec: Vector)
+
+class DCTSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("forward transform of discrete cosine matches jTransforms result") {
+    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
+    val inverse = false
+
+    testDCT(data, inverse)
+  }
+
+  test("inverse transform of discrete cosine matches jTransforms result") {
+    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
+    val inverse = true
+
+    testDCT(data, inverse)
+  }
+
+  private def testDCT(data: Vector, inverse: Boolean): Unit = {
+    val expectedResultBuffer = data.toArray.clone()
+    if (inverse) {
+      (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true)
+    } else {
+      (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true)
+    }
+    val expectedResult = Vectors.dense(expectedResultBuffer)
+
+    val dataset = sqlContext.createDataFrame(Seq(
+      DCTTestData(data, expectedResult)
+    ))
+
+    val transformer = new DCT()
+      .setInputCol("vec")
+      .setOutputCol("resultVec")
+      .setInverse(inverse)
+
+    transformer.transform(dataset)
+      .select("resultVec", "wantedVec")
+      .collect()
+      .foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
+        assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
+      }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
new file mode 100644
index 0000000000000..c452054bec92f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{Row, SQLContext}
+
+class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("MinMaxScaler fit basic case") {
+    val sqlContext = new SQLContext(sc)
+
+    val data = Array(
+      Vectors.dense(1, 0, Long.MinValue),
+      Vectors.dense(2, 0, 0),
+      Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)),
+      Vectors.sparse(3, Array(0), Array(1.5)))
+
+    val expected: Array[Vector] = Array(
+      Vectors.dense(-5, 0, -5),
+      Vectors.dense(0, 0, 0),
+      Vectors.sparse(3, Array(0, 2), Array(5, 5)),
+      Vectors.sparse(3, Array(0), Array(-2.5)))
+
+    val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+    val scaler = new MinMaxScaler()
+      .setInputCol("features")
+      .setOutputCol("scaled")
+      .setMin(-5)
+      .setMax(5)
+
+    val model = scaler.fit(df)
+    model.transform(df).select("expected", "scaled").collect()
+      .foreach { case Row(vector1: Vector, vector2: Vector) =>
+        assert(vector1.equals(vector2), "Transformed vector differs from the expected vector.")
+      }
+  }
+
+  test("MinMaxScaler arguments max must be larger than min") {
+    withClue("arguments max must be larger than min") {
+      intercept[IllegalArgumentException] {
+        val scaler = new MinMaxScaler().setMin(10).setMax(0)
+        scaler.validateParams()
+      }
+      intercept[IllegalArgumentException] {
+        val scaler = new MinMaxScaler().setMin(0).setMax(0)
+        scaler.validateParams()
+      }
+    }
+  }
+}
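
For reference, the expected vectors in the MinMaxScaler test can be reproduced by hand with the usual min-max rescaling formula, (e - E_min) / (E_max - E_min) * (max - min) + min, with a constant column mapped to the midpoint of the output range. A minimal sketch (the object and helper names are illustrative, not part of the suite):

    object MinMaxRescaleSketch {
      // Min-max rescaling of a single value given the column statistics.
      def rescale(e: Double, eMin: Double, eMax: Double, min: Double, max: Double): Double =
        if (eMax == eMin) 0.5 * (max + min) // constant column: map to the midpoint
        else (e - eMin) / (eMax - eMin) * (max - min) + min

      def main(args: Array[String]): Unit = {
        // First feature of the test data above: values 1, 2, 3, 1.5 rescaled into [-5, 5].
        val column = Seq(1.0, 2.0, 3.0, 1.5)
        val (eMin, eMax) = (column.min, column.max)
        column.foreach { e =>
          println(s"$e -> ${rescale(e, eMin, eMax, min = -5.0, max = 5.0)}")
        }
        // Prints -5.0, 0.0, 5.0, -2.5, matching the first components of the expected vectors.
      }
    }
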
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
new file mode 100644
index 0000000000000..d0ae36b28c7a9
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.distributed.RowMatrix
+import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
+import org.apache.spark.sql.Row
+
+class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    ParamsSuite.checkParams(new PCA)
+    val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
+    val model = new PCAModel("pca", new OldPCAModel(2, mat))
+    ParamsSuite.checkParams(model)
+  }
+
+  test("pca") {
+    val data = Array(
+      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
+      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+    )
+
+    val dataRDD = sc.parallelize(data, 2)
+
+    val mat = new RowMatrix(dataRDD)
+    val pc = mat.computePrincipalComponents(3)
+    val expected = mat.multiply(pc).rows
+
+    val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
+
+    val pca = new PCA()
+      .setInputCol("features")
+      .setOutputCol("pca_features")
+      .setK(3)
+      .fit(df)
+
+    pca.transform(df).select("pca_features", "expected").collect().foreach {
+      case Row(x: Vector, y: Vector) =>
+        assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 8c85c96d5c6d8..03120c828ca96 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import scala.beans.{BeanInfo, BeanProperty}
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{Logging, SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 
-class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
 
   import VectorIndexerSuite.FeatureData
 
@@ -113,11 +113,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
     model.transform(sparsePoints1) // should work
     intercept[SparkException] {
       model.transform(densePoints2).collect()
-      println("Did not throw error when fit, transform were called on vectors of different lengths")
+      logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
     }
     intercept[SparkException] {
       vectorIndexer.fit(badPoints)
-      println("Did not throw error when fitting vectors of different lengths in same RDD.")
+      logInfo("Did not throw error when fitting vectors of different lengths in same RDD.")
     }
   }
 
@@ -196,7 +196,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
         }
       } catch {
         case e: org.scalatest.exceptions.TestFailedException =>
-          println(errMsg)
+          logError(errMsg)
           throw e
       }
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 98fb3d3f5f22c..9682edcd9ba84 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 
 /**
@@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("GBTRegressor behaves reasonably on toy data") {
+    val df = sqlContext.createDataFrame(Seq(
+      LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
+      LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
+      LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
+      LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)),
+      LabeledPoint(9, Vectors.dense(1, 2, 6, 4)),
+      LabeledPoint(-4, Vectors.dense(6, 3, 2, 2))
+    ))
+    val gbt = new GBTRegressor()
+      .setMaxDepth(2)
+      .setMaxIter(2)
+    val model = gbt.fit(df)
+    val preds = model.transform(df)
+    val predictions = preds.select("prediction").map(_.getDouble(0))
+    // Checks based on SPARK-8736 (to ensure it is not doing classification)
+    assert(predictions.max() > 2)
+    assert(predictions.min() < -1)
+  }
+
   // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132
   /*
   test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index ad1e9da692ee2..cf120cf2a4b47 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.DenseVector
+import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -28,26 +28,26 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
   @transient var dataset: DataFrame = _
   @transient var datasetWithoutIntercept: DataFrame = _
 
-  /**
-   * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
-   * is the same as the one trained by R's glmnet package. The following instruction
-   * describes how to reproduce the data in R.
-   *
-   * import org.apache.spark.mllib.util.LinearDataGenerator
-   * val data =
-   *   sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
-   *     Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
-   * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
-   *   .saveAsTextFile("path")
+  /*
+     In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
+     is the same as the one trained by R's glmnet package. The following instruction
+     describes how to reproduce the data in R.
+
+     import org.apache.spark.mllib.util.LinearDataGenerator
+     val data =
+       sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
+         Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
+     data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
+       .saveAsTextFile("path")
    */
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
-    /**
-     * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
-     * training model without intercept
+    /*
+       datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
+       training model without intercept
      */
     datasetWithoutIntercept = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
@@ -59,27 +59,26 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     val trainer = new LinearRegression
     val model = trainer.fit(dataset)
 
-    /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * library("glmnet")
-     * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
-     * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
-     * label <- as.numeric(data$V1)
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
-     * > weights
-     *  3 x 1 sparse Matrix of class "dgCMatrix"
-     *                           s0
-     * (Intercept)         6.300528
-     * as.numeric.data.V2. 4.701024
-     * as.numeric.data.V3. 7.198257
+    /*
+       Using the following R code to load the data and train the model using glmnet package.
+
+       library("glmnet")
+       data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+       features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
+       label <- as.numeric(data$V1)
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
+       > weights
+        3 x 1 sparse Matrix of class "dgCMatrix"
+                                 s0
+       (Intercept)         6.300528
+       as.numeric.data.V2. 4.701024
+       as.numeric.data.V3. 7.198257
      */
     val interceptR = 6.298698
-    val weightsR = Array(4.700706, 7.199082)
+    val weightsR = Vectors.dense(4.700706, 7.199082)
 
     assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    assert(model.weights ~= weightsR relTol 1E-3)
 
     model.transform(dataset).select("features", "prediction").collect().foreach {
       case Row(features: DenseVector, prediction1: Double) =>
@@ -94,56 +93,53 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     val model = trainer.fit(dataset)
     val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
 
-    /**
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
-     *   intercept = FALSE))
-     * > weights
-     *  3 x 1 sparse Matrix of class "dgCMatrix"
-     *                           s0
-     * (Intercept)         .
-     * as.numeric.data.V2. 6.995908
-     * as.numeric.data.V3. 5.275131
+    /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
+         intercept = FALSE))
+       > weights
+        3 x 1 sparse Matrix of class "dgCMatrix"
+                                 s0
+       (Intercept)         .
+       as.numeric.data.V2. 6.995908
+       as.numeric.data.V3. 5.275131
      */
-    val weightsR = Array(6.995908, 5.275131)
-
-    assert(model.intercept ~== 0 relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
-    /**
-     * Then again with the data with no intercept:
-     * > weightsWithoutIntercept
-     * 3 x 1 sparse Matrix of class "dgCMatrix"
-     *                             s0
-     * (Intercept)           .
-     * as.numeric.data3.V2. 4.70011
-     * as.numeric.data3.V3. 7.19943
+    val weightsR = Vectors.dense(6.995908, 5.275131)
+
+    assert(model.intercept ~== 0 absTol 1E-3)
+    assert(model.weights ~= weightsR relTol 1E-3)
+    /*
+       Then again with the data with no intercept:
+       > weightsWithoutIntercept
+       3 x 1 sparse Matrix of class "dgCMatrix"
+                                   s0
+       (Intercept)           .
+       as.numeric.data3.V2. 4.70011
+       as.numeric.data3.V3. 7.19943
      */
-    val weightsWithoutInterceptR = Array(4.70011, 7.19943)
+    val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
 
-    assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
-    assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
-    assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
+    assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3)
+    assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3)
   }
 
   test("linear regression with intercept with L1 regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
     val model = trainer.fit(dataset)
 
-    /**
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
-     * > weights
-     *  3 x 1 sparse Matrix of class "dgCMatrix"
-     *                           s0
-     * (Intercept)         6.24300
-     * as.numeric.data.V2. 4.024821
-     * as.numeric.data.V3. 6.679841
+    /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
+       > weights
+        3 x 1 sparse Matrix of class "dgCMatrix"
+                                 s0
+       (Intercept)         6.24300
+       as.numeric.data.V2. 4.024821
+       as.numeric.data.V3. 6.679841
      */
     val interceptR = 6.24300
-    val weightsR = Array(4.024821, 6.679841)
+    val weightsR = Vectors.dense(4.024821, 6.679841)
 
     assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    assert(model.weights ~= weightsR relTol 1E-3)
 
     model.transform(dataset).select("features", "prediction").collect().foreach {
       case Row(features: DenseVector, prediction1: Double) =>
@@ -158,22 +154,21 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setFitIntercept(false)
     val model = trainer.fit(dataset)
 
-    /**
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
-     *   intercept=FALSE))
-     * > weights
-     *  3 x 1 sparse Matrix of class "dgCMatrix"
-     *                           s0
-     * (Intercept)          .
-     * as.numeric.data.V2. 6.299752
-     * as.numeric.data.V3. 4.772913
+    /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+         intercept=FALSE))
+       > weights
+        3 x 1 sparse Matrix of class "dgCMatrix"
+                                 s0
+       (Intercept)          .
+       as.numeric.data.V2. 6.299752
+       as.numeric.data.V3. 4.772913
      */
     val interceptR = 0.0
-    val weightsR = Array(6.299752, 4.772913)
+    val weightsR = Vectors.dense(6.299752, 4.772913)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    assert(model.intercept ~== interceptR absTol 1E-5)
+    assert(model.weights ~= weightsR relTol 1E-3)
 
     model.transform(dataset).select("features", "prediction").collect().foreach {
       case Row(features: DenseVector, prediction1: Double) =>
@@ -187,21 +182,20 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
     val model = trainer.fit(dataset)
 
-    /**
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
-     * > weights
-     *  3 x 1 sparse Matrix of class "dgCMatrix"
-     *                           s0
-     * (Intercept)         6.328062
-     * as.numeric.data.V2. 3.222034
-     * as.numeric.data.V3. 4.926260
+    /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
+       > weights
+        3 x 1 sparse Matrix of class "dgCMatrix"
+                                 s0
+       (Intercept)         6.328062
+       as.numeric.data.V2. 3.222034
+       as.numeric.data.V3. 4.926260
      */
     val interceptR = 5.269376
-    val weightsR = Array(3.736216, 5.712356)
+    val weightsR = Vectors.dense(3.736216, 5.712356)
 
     assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    assert(model.weights ~= weightsR relTol 1E-3)
 
     model.transform(dataset).select("features", "prediction").collect().foreach {
       case Row(features: DenseVector, prediction1: Double) =>
@@ -216,22 +210,21 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setFitIntercept(false)
     val model = trainer.fit(dataset)
 
-    /**
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
-     *   intercept = FALSE))
-     * > weights
-     *  3 x 1 sparse Matrix of class "dgCMatrix"
-     *                           s0
-     * (Intercept)         .
-     * as.numeric.data.V2. 5.522875
-     * as.numeric.data.V3. 4.214502
+    /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+         intercept = FALSE))
+       > weights
+        3 x 1 sparse Matrix of class "dgCMatrix"
+                                 s0
+       (Intercept)         .
+       as.numeric.data.V2. 5.522875
+       as.numeric.data.V3. 4.214502
      */
     val interceptR = 0.0
-    val weightsR = Array(5.522875, 4.214502)
+    val weightsR = Vectors.dense(5.522875, 4.214502)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    assert(model.intercept ~== interceptR absTol 1E-3)
+    assert(model.weights ~== weightsR relTol 1E-3)
 
     model.transform(dataset).select("features", "prediction").collect().foreach {
       case Row(features: DenseVector, prediction1: Double) =>
@@ -245,21 +238,20 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
     val model = trainer.fit(dataset)
 
-    /**
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
-     * > weights
-     * 3 x 1 sparse Matrix of class "dgCMatrix"
-     * s0
-     * (Intercept)         6.324108
-     * as.numeric.data.V2. 3.168435
-     * as.numeric.data.V3. 5.200403
+    /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
+       > weights
+       3 x 1 sparse Matrix of class "dgCMatrix"
+       s0
+       (Intercept)         6.324108
+       as.numeric.data.V2. 3.168435
+       as.numeric.data.V3. 5.200403
      */
     val interceptR = 5.696056
-    val weightsR = Array(3.670489, 6.001122)
+    val weightsR = Vectors.dense(3.670489, 6.001122)
 
     assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    assert(model.weights ~== weightsR relTol 1E-3)
 
     model.transform(dataset).select("features", "prediction").collect().foreach {
       case Row(features: DenseVector, prediction1: Double) =>
@@ -274,22 +266,21 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setFitIntercept(false)
     val model = trainer.fit(dataset)
 
-    /**
-     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
-     *   intercept=FALSE))
-     * > weights
-     * 3 x 1 sparse Matrix of class "dgCMatrix"
-     * s0
-     * (Intercept)         .
-     * as.numeric.dataM.V2. 5.673348
-     * as.numeric.dataM.V3. 4.322251
+    /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+         intercept=FALSE))
+       > weights
+       3 x 1 sparse Matrix of class "dgCMatrix"
+       s0
+       (Intercept)         .
+       as.numeric.dataM.V2. 5.673348
+       as.numeric.dataM.V3. 4.322251
      */
     val interceptR = 0.0
-    val weightsR = Array(5.673348, 4.322251)
+    val weightsR = Vectors.dense(5.673348, 4.322251)
 
-    assert(model.intercept ~== interceptR relTol 1E-3)
-    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
-    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+    assert(model.intercept ~== interceptR absTol 1E-3)
+    assert(model.weights ~= weightsR relTol 1E-3)
 
     model.transform(dataset).select("features", "prediction").collect().foreach {
       case Row(features: DenseVector, prediction1: Double) =>
@@ -298,4 +289,63 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }
+
+  test("linear regression model training summary") {
+    val trainer = new LinearRegression
+    val model = trainer.fit(dataset)
+
+    // Training results for the model should be available
+    assert(model.hasSummary)
+
+    // Residuals in [[LinearRegressionResults]] should equal those manually computed
+    dataset.select("features", "label")
+      .map { case Row(features: DenseVector, label: Double) =>
+        val prediction =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        prediction - label
+      }
+      .zip(model.summary.residuals.map(_.getDouble(0)))
+      .collect()
+      .foreach { case (manualResidual: Double, resultResidual: Double) =>
+        assert(manualResidual ~== resultResidual relTol 1E-5)
+      }
+
+    /*
+       Use the following R code to generate model training results.
+
+       predictions <- predict(fit, newx=features)
+       residuals <- predictions - label
+       > mean(residuals^2) # MSE
+       [1] 0.009720325
+       > mean(abs(residuals)) # MAD
+       [1] 0.07863206
+       > cor(predictions, label)^2 # r^2
+               [,1]
+       s0 0.9998749
+     */
+    assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
+    assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5)
+    assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)
+
+    // Objective function should be monotonically decreasing for linear regression
+    assert(
+      model.summary
+        .objectiveHistory
+        .sliding(2)
+        .forall(x => x(0) >= x(1)))
+  }
+
+  test("linear regression model testset evaluation summary") {
+    val trainer = new LinearRegression
+    val model = trainer.fit(dataset)
+
+    // Evaluating on training dataset should yield results summary equal to training summary
+    val testSummary = model.evaluate(dataset)
+    assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
+    assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
+    model.summary.residuals.select("residuals").collect()
+      .zip(testSummary.residuals.select("residuals").collect())
+      .foreach { case (Row(r1: Double), Row(r2: Double)) => assert(r1 ~== r2 relTol 1E-5) }
+  }
+
 }
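
As background for the summary metrics asserted above, a small self-contained sketch of how the residual-based metrics are derived (the numbers and names below are made up for illustration; R² is computed as 1 - SS_res / SS_tot, which coincides with the squared correlation reported in the R snippet for a least-squares fit with an intercept):

    object RegressionMetricsSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical (prediction, label) pairs.
        val pairs = Seq((1.1, 1.0), (1.9, 2.0), (3.2, 3.0), (3.9, 4.0))
        val residuals = pairs.map { case (p, y) => p - y }

        val mse = residuals.map(r => r * r).sum / residuals.size
        val mae = residuals.map(math.abs).sum / residuals.size

        val labels = pairs.map(_._2)
        val meanLabel = labels.sum / labels.size
        val ssRes = residuals.map(r => r * r).sum
        val ssTot = labels.map(y => (y - meanLabel) * (y - meanLabel)).sum
        val r2 = 1.0 - ssRes / ssTot

        println(s"MSE = $mse, MAE = $mae, R^2 = $r2")
      }
    }
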
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index e8f3d0c4db20a..2473510e13514 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -196,6 +196,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
       .setStepSize(10.0)
       .setRegParam(0.0)
       .setNumIterations(20)
+      .setConvergenceTol(0.0005)
 
     val model = lr.run(testRDD)
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 406affa25539d..03a8a2538b464 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -99,9 +99,13 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
 
     // Check: per-doc topic distributions
     val topicDistributions = model.topicDistributions.collect()
+
     //  Ensure all documents are covered.
-    assert(topicDistributions.length === tinyCorpus.length)
-    assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
+    // SPARK-5562: topicDistributions only returns the topic distributions of non-empty docs,
+    // so compare against nonEmptyTinyCorpus instead of tinyCorpus.
+    val nonEmptyTinyCorpus = getNonEmptyDoc(tinyCorpus)
+    assert(topicDistributions.length === nonEmptyTinyCorpus.length)
+    assert(nonEmptyTinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
     //  Ensure we have proper distributions
     topicDistributions.foreach { case (docId, topicDistribution) =>
       assert(topicDistribution.size === tinyK)
@@ -232,12 +236,17 @@ private[clustering] object LDASuite {
   }
 
   def tinyCorpus: Array[(Long, Vector)] = Array(
+    Vectors.dense(0, 0, 0, 0, 0), // empty doc
     Vectors.dense(1, 3, 0, 2, 8),
     Vectors.dense(0, 2, 1, 0, 4),
     Vectors.dense(2, 3, 12, 3, 1),
+    Vectors.dense(0, 0, 0, 0, 0), // empty doc
     Vectors.dense(0, 3, 1, 9, 8),
     Vectors.dense(1, 1, 4, 2, 6)
   ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
   assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data
 
+  def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter {
+    case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
new file mode 100644
index 0000000000000..77a2773c36f56
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("association rules using String type") {
+    val freqItemsets = sc.parallelize(Seq(
+      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
+      (Set("r"), 3L),
+      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
+      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
+      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
+      (Set("t", "y", "x"), 3L),
+      (Set("t", "y", "x", "z"), 3L)
+    ).map {
+      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
+    })
+
+    val ar = new AssociationRules()
+
+    val results1 = ar
+      .setMinConfidence(0.9)
+      .run(freqItemsets)
+      .collect()
+
+    /* Verify results using the `R` code:
+       transactions = as(sapply(
+         list("r z h k p",
+              "z y x w v u t s",
+              "s x o n r",
+              "x z y m t s q e",
+              "z",
+              "x z y r q t p"),
+         FUN=function(x) strsplit(x," ",fixed=TRUE)),
+         "transactions")
+       ars = apriori(transactions,
+                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
+       arsDF = as(ars, "data.frame")
+       arsDF$support = arsDF$support * length(transactions)
+       names(arsDF)[names(arsDF) == "support"] = "freq"
+       > nrow(arsDF)
+       [1] 23
+       > sum(arsDF$confidence == 1)
+       [1] 23
+     */
+    assert(results1.size === 23)
+    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+
+    val results2 = ar
+      .setMinConfidence(0)
+      .run(freqItemsets)
+      .collect()
+
+    /* Verify results using the `R` code:
+       ars = apriori(transactions,
+                  parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2))
+       arsDF = as(ars, "data.frame")
+       arsDF$support = arsDF$support * length(transactions)
+       names(arsDF)[names(arsDF) == "support"] = "freq"
+       nrow(arsDF)
+       sum(arsDF$confidence == 1)
+       > nrow(arsDF)
+       [1] 30
+       > sum(arsDF$confidence == 1)
+       [1] 23
+     */
+    assert(results2.size === 30)
+    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+  }
+}
+
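
A reminder of what the confidence assertions above check: the confidence of a rule X => Y is freq(X ∪ Y) / freq(X), so a rule survives setMinConfidence(0.9) only when the consequent co-occurs with the antecedent at least 90% of the time. A tiny sketch over hand-written itemset counts (names and numbers are illustrative only):

    object RuleConfidenceSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical frequent-itemset counts.
        val freq = Map(
          Set("x") -> 4L,
          Set("z") -> 5L,
          Set("x", "z") -> 3L)

        // Confidence of the rule {x} => {z}: freq({x, z}) / freq({x}).
        val confidence = freq(Set("x", "z")).toDouble / freq(Set("x"))
        println(s"confidence(x => z) = $confidence") // 0.75, below a 0.9 threshold
      }
    }
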
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 1a8a1e79f2810..4a9bfdb348d9f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
 
 
-  test("FP-Growth frequent itemsets using String type") {
+  test("FP-Growth using String type") {
     val transactions = Seq(
       "r z h k p",
       "z y x w v u t s",
@@ -38,18 +38,59 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .setOrdered(false)
       .run(rdd)
+
+    /* Verify results using the `R` code:
+       transactions = as(sapply(
+         list("r z h k p",
+              "z y x w v u t s",
+              "s x o n r",
+              "x z y m t s q e",
+              "z",
+              "x z y r q t p"),
+         FUN=function(x) strsplit(x," ",fixed=TRUE)),
+         "transactions")
+       > eclat(transactions, parameter = list(support = 0.9))
+       ...
+       eclat - zero frequent items
+       set of 0 itemsets
+     */
     assert(model6.freqItemsets.count() === 0)
 
     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .setOrdered(false)
       .run(rdd)
     val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
       (itemset.items.toSet, itemset.freq)
     }
+
+    /* Verify results using the `R` code:
+       fp = eclat(transactions, parameter = list(support = 0.5))
+       fpDF = as(sort(fp), "data.frame")
+       fpDF$support = fpDF$support * length(transactions)
+       names(fpDF)[names(fpDF) == "support"] = "freq"
+       > fpDF
+              items freq
+       13       {z}    5
+       14       {x}    4
+       1      {s,x}    3
+       2  {t,x,y,z}    3
+       3    {t,y,z}    3
+       4    {t,x,y}    3
+       5    {x,y,z}    3
+       6      {y,z}    3
+       7      {x,y}    3
+       8      {t,y}    3
+       9    {t,x,z}    3
+       10     {t,z}    3
+       11     {t,x}    3
+       12     {x,z}    3
+       15       {t}    3
+       16       {y}    3
+       17       {s}    3
+       18       {r}    3
+     */
     val expected = Set(
       (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
       (Set("r"), 3L),
@@ -63,19 +104,35 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .setOrdered(false)
       .run(rdd)
+
+    /* Verify results using the `R` code:
+       fp = eclat(transactions, parameter = list(support = 0.3))
+       fpDF = as(fp, "data.frame")
+       fpDF$support = fpDF$support * length(transactions)
+       names(fpDF)[names(fpDF) == "support"] = "freq"
+       > nrow(fpDF)
+       [1] 54
+     */
     assert(model2.freqItemsets.count() === 54)
 
     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .setOrdered(false)
       .run(rdd)
+
+    /* Verify results using the `R` code:
+       fp = eclat(transactions, parameter = list(support = 0.1))
+       fpDF = as(fp, "data.frame")
+       fpDF$support = fpDF$support * length(transactions)
+       names(fpDF)[names(fpDF) == "support"] = "freq"
+       > nrow(fpDF)
+       [1] 625
+     */
     assert(model1.freqItemsets.count() === 625)
   }
 
-  test("FP-Growth frequent sequences using String type"){
+  test("FP-Growth String type association rule generation") {
     val transactions = Seq(
       "r z h k p",
       "z y x w v u t s",
@@ -86,36 +143,38 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
       .map(_.split(" "))
     val rdd = sc.parallelize(transactions, 2).cache()
 
-    val fpg = new FPGrowth()
-
-    val model1 = fpg
+    /* Verify results using the `R` code:
+       transactions = as(sapply(
+         list("r z h k p",
+              "z y x w v u t s",
+              "s x o n r",
+              "x z y m t s q e",
+              "z",
+              "x z y r q t p"),
+         FUN=function(x) strsplit(x," ",fixed=TRUE)),
+         "transactions")
+       ars = apriori(transactions,
+                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
+       arsDF = as(ars, "data.frame")
+       arsDF$support = arsDF$support * length(transactions)
+       names(arsDF)[names(arsDF) == "support"] = "freq"
+       > nrow(arsDF)
+       [1] 23
+       > sum(arsDF$confidence == 1)
+       [1] 23
+     */
+    val rules = (new FPGrowth())
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .setOrdered(true)
       .run(rdd)
+      .generateAssociationRules(0.9)
+      .collect()
 
-    /*
-      Use the following R code to verify association rules using arulesSequences package.
-
-      data = read_baskets("path", info = c("sequenceID","eventID","SIZE"))
-      freqItemSeq = cspade(data, parameter = list(support = 0.5))
-      resSeq = as(freqItemSeq, "data.frame")
-      resSeq$support = resSeq$support * length(transactions)
-      names(resSeq)[names(resSeq) == "support"] = "freq"
-      resSeq
-     */
-    val expected = Set(
-      (Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L),
-      (Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L),
-      (Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L)
-    )
-    val freqItemseqs1 = model1.freqItemsets.collect().map { itemset =>
-      (itemset.items.toSeq, itemset.freq)
-    }.toSet
-    assert(freqItemseqs1 == expected)
+    assert(rules.size === 23)
+    assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
   }
 
-  test("FP-Growth frequent itemsets using Int type") {
+  test("FP-Growth using Int type") {
     val transactions = Seq(
       "1 2 3",
       "1 2 3 4",
@@ -132,20 +191,53 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .setOrdered(false)
       .run(rdd)
+
+    /* Verify results using the `R` code:
+       transactions = as(sapply(
+         list("1 2 3",
+              "1 2 3 4",
+              "5 4 3 2 1",
+              "6 5 4 3 2 1",
+              "2 4",
+              "1 3",
+              "1 7"),
+         FUN=function(x) strsplit(x," ",fixed=TRUE)),
+         "transactions")
+       > eclat(transactions, parameter = list(support = 0.9))
+       ...
+       eclat - zero frequent items
+       set of 0 itemsets
+     */
     assert(model6.freqItemsets.count() === 0)
 
     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .setOrdered(false)
       .run(rdd)
     assert(model3.freqItemsets.first().items.getClass === Array(1).getClass,
       "frequent itemsets should use primitive arrays")
     val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
       (itemset.items.toSet, itemset.freq)
     }
+
+    /* Verify results using the `R` code:
+       fp = eclat(transactions, parameter = list(support = 0.5))
+       fpDF = as(sort(fp), "data.frame")
+       fpDF$support = fpDF$support * length(transactions)
+       names(fpDF)[names(fpDF) == "support"] = "freq"
+       > fpDF
+          items freq
+      6     {1}    6
+      3   {1,3}    5
+      7     {2}    5
+      8     {3}    5
+      1   {2,4}    4
+      2 {1,2,3}    4
+      4   {2,3}    4
+      5   {1,2}    4
+      9     {4}    4
+     */
     val expected = Set(
       (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
       (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
@@ -155,15 +247,31 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .setOrdered(false)
       .run(rdd)
+
+    /* Verify results using the `R` code:
+       fp = eclat(transactions, parameter = list(support = 0.3))
+       fpDF = as(fp, "data.frame")
+       fpDF$support = fpDF$support * length(transactions)
+       names(fpDF)[names(fpDF) == "support"] = "freq"
+       > nrow(fpDF)
+       [1] 15
+     */
     assert(model2.freqItemsets.count() === 15)
 
     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .setOrdered(false)
       .run(rdd)
+
+    /* Verify results using the `R` code:
+       fp = eclat(transactions, parameter = list(support = 0.1))
+       fpDF = as(fp, "data.frame")
+       fpDF$support = fpDF$support * length(transactions)
+       names(fpDF)[names(fpDF) == "support"] = "freq"
+       > nrow(fpDF)
+       [1] 65
+     */
     assert(model1.freqItemsets.count() === 65)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
new file mode 100644
index 0000000000000..413436d3db85f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("PrefixSpan using Integer type") {
+
+    /*
+      library("arulesSequences")
+      prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
+      freqItemSeq = cspade(
+        prefixSpanSeqs,
+        parameter = list(support =
+          2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2 ))
+      resSeq = as(freqItemSeq, "data.frame")
+      resSeq
+    */
+
+    val sequences = Array(
+      Array(1, 3, 4, 5),
+      Array(2, 3, 1),
+      Array(2, 4, 1),
+      Array(3, 1, 3, 4, 5),
+      Array(3, 4, 4, 3),
+      Array(6, 5, 3))
+
+    val rdd = sc.parallelize(sequences, 2).cache()
+
+    def compareResult(
+        expectedValue: Array[(Array[Int], Long)],
+        actualValue: Array[(Array[Int], Long)]): Boolean = {
+      val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
+        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
+      }
+      val sortedActualValue = actualValue.sortWith{ (x, y) =>
+        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
+      }
+      sortedExpectedValue.length == sortedActualValue.length &&
+        sortedExpectedValue.zip(sortedActualValue).forall { x =>
+          x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2
+        }
+    }
+
+    val prefixspan = new PrefixSpan()
+      .setMinSupport(0.33)
+      .setMaxPatternLength(50)
+    val result1 = prefixspan.run(rdd)
+    val expectedValue1 = Array(
+      (Array(1), 4L),
+      (Array(1, 3), 2L),
+      (Array(1, 3, 4), 2L),
+      (Array(1, 3, 4, 5), 2L),
+      (Array(1, 3, 5), 2L),
+      (Array(1, 4), 2L),
+      (Array(1, 4, 5), 2L),
+      (Array(1, 5), 2L),
+      (Array(2), 2L),
+      (Array(2, 1), 2L),
+      (Array(3), 5L),
+      (Array(3, 1), 2L),
+      (Array(3, 3), 2L),
+      (Array(3, 4), 3L),
+      (Array(3, 4, 5), 2L),
+      (Array(3, 5), 2L),
+      (Array(4), 4L),
+      (Array(4, 5), 2L),
+      (Array(5), 3L)
+    )
+    assert(compareResult(expectedValue1, result1.collect()))
+
+    prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
+    val result2 = prefixspan.run(rdd)
+    val expectedValue2 = Array(
+      (Array(1), 4L),
+      (Array(3), 5L),
+      (Array(3, 4), 3L),
+      (Array(4), 4L),
+      (Array(5), 3L)
+    )
+    assert(compareResult(expectedValue2, result2.collect()))
+
+    prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
+    val result3 = prefixspan.run(rdd)
+    val expectedValue3 = Array(
+      (Array(1), 4L),
+      (Array(1, 3), 2L),
+      (Array(1, 4), 2L),
+      (Array(1, 5), 2L),
+      (Array(2, 1), 2L),
+      (Array(2), 2L),
+      (Array(3), 5L),
+      (Array(3, 1), 2L),
+      (Array(3, 3), 2L),
+      (Array(3, 4), 3L),
+      (Array(3, 5), 2L),
+      (Array(4), 4L),
+      (Array(4, 5), 2L),
+      (Array(5), 3L)
+    )
+    assert(compareResult(expectedValue3, result3.collect()))
+  }
+}
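
The `compareResult` helper above canonicalizes both pattern lists by sorting on a string key before comparing them pair by pair. The same order-insensitive check can be expressed more directly as set equality over `(pattern, count)` pairs; a minimal Python sketch, illustrative only and not part of the patch:

```python
def same_patterns(expected, actual):
    """Order-insensitive comparison of (pattern, count) pairs."""
    def as_set(pairs):
        return {(tuple(pattern), count) for pattern, count in pairs}
    return as_set(expected) == as_set(actual)

expected = [([1], 4), ([1, 3], 2), ([3], 5)]
actual = [([3], 5), ([1], 4), ([1, 3], 2)]
assert same_patterns(expected, actual)
```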
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 8dbb70f5d1c4c..a270ba2562db9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -455,4 +455,14 @@ class MatricesSuite extends SparkFunSuite {
     lines = mat.toString(5, 100).lines.toArray
     assert(lines.size == 5 && lines.forall(_.size <= 100))
   }
+
+  test("numNonzeros and numActives") {
+    val dm1 = Matrices.dense(3, 2, Array(0, 0, -1, 1, 0, 1))
+    assert(dm1.numNonzeros === 3)
+    assert(dm1.numActives === 6)
+
+    val sm1 = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
+    assert(sm1.numNonzeros === 1)
+    assert(sm1.numActives === 3)
+  }
 }
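
For the new `numNonzeros and numActives` test: `numActives` counts the explicitly stored entries (every entry of a dense matrix, every stored value of a sparse one, including stored zeros), while `numNonzeros` counts only values that are actually non-zero. The distinction over the same sparse data, in a few lines of Python (illustrative only):

```python
# Stored values mirroring Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
values = [0.0, -1.2, 0.0]

num_actives = len(values)                          # 3 stored entries
num_nonzeros = len([v for v in values if v != 0])  # only 1 entry is actually non-zero
print("numActives = %d, numNonzeros = %d" % (num_actives, num_nonzeros))
```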
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index c4ae0a16f7c04..178d95a7b94ec 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -21,10 +21,10 @@ import scala.util.Random
 
 import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{Logging, SparkException, SparkFunSuite}
 import org.apache.spark.mllib.util.TestingUtils._
 
-class VectorsSuite extends SparkFunSuite {
+class VectorsSuite extends SparkFunSuite with Logging {
 
   val arr = Array(0.1, 0.0, 0.3, 0.4)
   val n = 4
@@ -142,7 +142,7 @@ class VectorsSuite extends SparkFunSuite {
     malformatted.foreach { s =>
       intercept[SparkException] {
         Vectors.parse(s)
-        println(s"Didn't detect malformatted string $s.")
+        logInfo(s"Didn't detect malformatted string $s.")
       }
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 4a7b99a976f0a..0ecb7a221a503 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(closeToZero(U * brzDiag(s) * V.t - localA))
   }
 
+  test("validate matrix sizes of svd") {
+    val k = 2
+    val A = new IndexedRowMatrix(indexedRows)
+    val svd = A.computeSVD(k, computeU = true)
+    assert(svd.U.numRows() === m)
+    assert(svd.U.numCols() === k)
+    assert(svd.s.size === k)
+    assert(svd.V.numRows === n)
+    assert(svd.V.numCols === k)
+  }
+
   test("validate k in svd") {
     val A = new IndexedRowMatrix(indexedRows)
     intercept[IllegalArgumentException] {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index a5a59e9fad5ae..13b754a03943a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
 object GradientDescentSuite {
@@ -82,11 +82,11 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with
     // Add a extra variable consisting of all 1.0's for the intercept.
     val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
     val data = testData.map { case LabeledPoint(label, features) =>
-      label -> Vectors.dense(1.0 +: features.toArray)
+      label -> MLUtils.appendBias(features)
     }
 
     val dataRDD = sc.parallelize(data, 2).cache()
-    val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray)
+    val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
 
     val (_, loss) = GradientDescent.runMiniBatchSGD(
       dataRDD,
@@ -139,6 +139,45 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with
       "The different between newWeights with/without regularization " +
         "should be initialWeightsWithIntercept.")
   }
+
+  test("iteration should end with convergence tolerance") {
+    val nPoints = 10000
+    val A = 2.0
+    val B = -1.5
+
+    val initialB = -1.0
+    val initialWeights = Array(initialB)
+
+    val gradient = new LogisticGradient()
+    val updater = new SimpleUpdater()
+    val stepSize = 1.0
+    val numIterations = 10
+    val regParam = 0
+    val miniBatchFrac = 1.0
+    val convergenceTolerance = 5.0e-1
+
+    // Add an extra variable consisting of all 1.0's for the intercept.
+    val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
+    val data = testData.map { case LabeledPoint(label, features) =>
+      label -> MLUtils.appendBias(features)
+    }
+
+    val dataRDD = sc.parallelize(data, 2).cache()
+    val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
+
+    val (_, loss) = GradientDescent.runMiniBatchSGD(
+      dataRDD,
+      gradient,
+      updater,
+      stepSize,
+      numIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept,
+      convergenceTolerance)
+
+    assert(loss.length < numIterations, "convergenceTolerance failed to stop optimization early")
+  }
 }
 
 class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
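
For the new convergence-tolerance test: minibatch SGD should stop recording losses before `numIterations` once successive solutions barely move. A hedged sketch of a typical relative-change stopping rule (the exact criterion inside `GradientDescent.runMiniBatchSGD` may differ in detail):

```python
import math

def converged(prev_weights, curr_weights, tol):
    # Relative change of the solution vector between two consecutive iterations.
    diff = math.sqrt(sum((p - c) ** 2 for p, c in zip(prev_weights, curr_weights)))
    scale = max(math.sqrt(sum(c * c for c in curr_weights)), 1.0)
    return diff < tol * scale

# With the loose tolerance used in the test above (5e-1), a rule like this fires
# well before the 10 configured iterations, so the recorded loss history is shorter.
print(converged([2.0, -1.5], [2.01, -1.49], tol=0.5))  # True
```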
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index d07b9d5b89227..75ae0eb32fb7b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -122,7 +122,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
       numGDIterations,
       regParam,
       miniBatchFrac,
-      initialWeightsWithIntercept)
+      initialWeightsWithIntercept,
+      convergenceTol)
 
     assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5,
       "The first losses of LBFGS and GD should be the same.")
@@ -221,7 +222,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
       numGDIterations,
       regParam,
       miniBatchFrac,
-      initialWeightsWithIntercept)
+      initialWeightsWithIntercept,
+      convergenceTol)
 
     // for class LBFGS and the optimize method, we only look at the weights
     assert(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 08a152ffc7a23..39537e7bb4c72 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -100,7 +100,7 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
     val testRDD = sc.parallelize(testData, 2).cache()
 
     val ls = new LassoWithSGD()
-    ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
+    ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005)
 
     val model = ls.run(testRDD, initialWeights)
     val weight0 = model.weights(0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index f5e2d31056cbd..a2a4c5f6b8b70 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -53,6 +53,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
       .setInitialWeights(Vectors.dense(0.0, 0.0))
       .setStepSize(0.2)
       .setNumIterations(25)
+      .setConvergenceTol(0.0001)
 
     // generate sequence of simulated data
     val numBatches = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index c292ced75e870..c3eeda012571c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.mllib.stat
 
 import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{Logging, SparkFunSuite}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
   SpearmanCorrelation}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
 
   // test input data
   val xData = Array(1.0, 0.0, -2.0)
@@ -146,7 +146,7 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
     for (i <- 0 until A.rows; j <- 0 until A.cols) {
       if (!approxEqual(A(i, j), B(i, j), threshold)) {
-        println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
+        logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
         return false
       }
     }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
index b084a5fb4313f..142b90e764a7c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
@@ -19,6 +19,10 @@ package org.apache.spark.mllib.stat
 
 import java.util.Random
 
+import org.apache.commons.math3.distribution.{ExponentialDistribution,
+  NormalDistribution, UniformRealDistribution}
+import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -153,4 +157,101 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
       Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
     }
   }
+
+  test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") {
+    // Create theoretical distributions
+    val stdNormalDist = new NormalDistribution(0, 1)
+    val expDist = new ExponentialDistribution(0.6)
+    val unifDist = new UniformRealDistribution()
+
+    // set seeds
+    val seed = 10L
+    stdNormalDist.reseedRandomGenerator(seed)
+    expDist.reseedRandomGenerator(seed)
+    unifDist.reseedRandomGenerator(seed)
+
+    // Sample data from the distributions and parallelize it
+    val n = 100000
+    val sampledNorm = sc.parallelize(stdNormalDist.sample(n), 10)
+    val sampledExp = sc.parallelize(expDist.sample(n), 10)
+    val sampledUnif = sc.parallelize(unifDist.sample(n), 10)
+
+    // Use a local Apache Commons Math KS test to verify the calculations
+    val ksTest = new KolmogorovSmirnovTest()
+    val pThreshold = 0.05
+
+    // Comparing a standard normal sample to a standard normal distribution
+    val result1 = Statistics.kolmogorovSmirnovTest(sampledNorm, "norm", 0, 1)
+    val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledNorm.collect())
+    val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n)
+    // Verify against the Apache Commons Math KS test
+    assert(result1.statistic ~== referenceStat1 relTol 1e-4)
+    assert(result1.pValue ~== referencePVal1 relTol 1e-4)
+    // Cannot reject null hypothesis
+    assert(result1.pValue > pThreshold)
+
+    // Comparing an exponential sample to a standard normal distribution
+    val result2 = Statistics.kolmogorovSmirnovTest(sampledExp, "norm", 0, 1)
+    val referenceStat2 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledExp.collect())
+    val referencePVal2 = 1 - ksTest.cdf(referenceStat2, n)
+    // Verify against the Apache Commons Math KS test
+    assert(result2.statistic ~== referenceStat2 relTol 1e-4)
+    assert(result2.pValue ~== referencePVal2 relTol 1e-4)
+    // reject null hypothesis
+    assert(result2.pValue < pThreshold)
+
+    // Test the use of a user-provided CDF function.
+    // The distribution is not serializable, so it has to be created inside the lambda.
+    val expCDF = (x: Double) => new ExponentialDistribution(0.2).cumulativeProbability(x)
+
+    // Comparing an exponential sample with mean X to an exponential distribution with mean Y
+    // Where X != Y
+    val result3 = Statistics.kolmogorovSmirnovTest(sampledExp, expCDF)
+    val referenceStat3 = ksTest.kolmogorovSmirnovStatistic(new ExponentialDistribution(0.2),
+      sampledExp.collect())
+    val referencePVal3 = 1 - ksTest.cdf(referenceStat3, sampledNorm.count().toInt)
+    // Verify against the Apache Commons Math KS test
+    assert(result3.statistic ~== referenceStat3 relTol 1e-4)
+    assert(result3.pValue ~== referencePVal3 relTol 1e-4)
+    // reject null hypothesis
+    assert(result3.pValue < pThreshold)
+  }
+
+  test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") {
+    /*
+      Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample
+      > sessionInfo()
+      R version 3.2.0 (2015-04-16)
+      Platform: x86_64-apple-darwin13.4.0 (64-bit)
+      > set.seed(20)
+      > v <- rnorm(20)
+      > v
+       [1]  1.16268529 -0.58592447  1.78546500 -1.33259371 -0.44656677  0.56960612
+       [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222
+      [13] -0.62812676  1.32322085 -1.52135057 -0.43742787  0.97057758  0.02822264
+      [19] -0.08578219  0.38921440
+      > ks.test(v, pnorm, alternative = "two.sided")
+
+               One-sample Kolmogorov-Smirnov test
+
+      data:  v
+      D = 0.18874, p-value = 0.4223
+      alternative hypothesis: two-sided
+    */
+
+    val rKSStat = 0.18874
+    val rKSPVal = 0.4223
+    val rData = sc.parallelize(
+      Array(
+        1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
+        -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
+        -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
+        -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
+        0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
+      )
+    )
+    val rCompResult = Statistics.kolmogorovSmirnovTest(rData, "norm", 0, 1)
+    assert(rCompResult.statistic ~== rKSStat relTol 1e-4)
+    assert(rCompResult.pValue ~== rKSPVal relTol 1e-4)
+  }
 }
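
Both new Kolmogorov-Smirnov tests exercise the one-sample statistic D = sup_x |F_n(x) - F(x)|, the largest gap between the sample's empirical CDF and the theoretical CDF. A self-contained Python sketch of that statistic (standard library only; shown purely to illustrate what the suite is checking, not the MLlib implementation):

```python
from __future__ import division
import math

def normal_cdf(x, mu=0.0, sigma=1.0):
    # Standard normal CDF via the error function.
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def ks_statistic(sample, cdf):
    """One-sample KS statistic: D = sup_x |F_n(x) - F(x)|."""
    xs = sorted(sample)
    n = len(xs)
    d = 0.0
    for i, x in enumerate(xs):
        f = cdf(x)
        # The empirical CDF jumps from i/n to (i + 1)/n at x; check the gap on both sides.
        d = max(d, abs((i + 1) / n - f), abs(f - i / n))
    return d

# First few values of the R-generated sample used in the test above.
sample = [1.16268529, -0.58592447, 1.785465, -1.33259371, -0.44656677]
print(ks_statistic(sample, normal_cdf))
```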
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 84dd3b342d4c0..2521b3342181a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.tree
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{Logging, SparkFunSuite}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
@@ -31,7 +31,7 @@ import org.apache.spark.util.Utils
 /**
  * Test suite for [[GradientBoostedTrees]].
  */
-class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
 
   test("Regression with continuous features: SquaredError") {
     GradientBoostedTreesSuite.testCombinations.foreach {
@@ -50,7 +50,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
           EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
         } catch {
           case e: java.lang.AssertionError =>
-            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
               s" subsamplingRate=$subsamplingRate")
             throw e
         }
@@ -80,7 +80,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
           EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae")
         } catch {
           case e: java.lang.AssertionError =>
-            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
               s" subsamplingRate=$subsamplingRate")
             throw e
         }
@@ -111,7 +111,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
           EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9)
         } catch {
           case e: java.lang.AssertionError =>
-            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
               s" subsamplingRate=$subsamplingRate")
             throw e
         }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
index fa4f74d71b7e7..16d7c3ab39b03 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
@@ -33,7 +33,7 @@ class NumericParserSuite extends SparkFunSuite {
     malformatted.foreach { s =>
       intercept[SparkException] {
         NumericParser.parse(s)
-        println(s"Didn't detect malformatted string $s.")
+        throw new RuntimeException(s"Didn't detect malformatted string $s.")
       }
     }
   }
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 6b514aaa1290d..7d27439cfde7a 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -39,6 +39,12 @@
 public class JavaUtils {
   private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class);
 
+  /**
+   * Define a default value for driver memory here since this value is referenced across the code
+   * base and nearly all files already use Utils.scala
+   */
+  public static final long DEFAULT_DRIVER_MEM_MB = 1024;
+
   /** Closes the given object, ignoring IOExceptions. */
   public static void closeQuietly(Closeable closeable) {
     try {
diff --git a/pom.xml b/pom.xml
index 4c18bd5e42c87..370c95dd03632 100644
--- a/pom.xml
+++ b/pom.xml
@@ -102,6 +102,7 @@
     external/twitter
     external/flume
     external/flume-sink
+    external/flume-assembly
     external/mqtt
     external/zeromq
     examples
@@ -127,7 +128,7 @@
     ${hadoop.version}
     0.98.7-hadoop2
     hbase
-    1.4.0
+    1.6.0
     3.4.5
     2.4.0
     org.spark-project.hive
@@ -160,6 +161,8 @@
     2.4.4
     1.1.1.7
     1.1.2
+    
+    false
 
     ${java.home}
 
@@ -176,6 +179,7 @@
     compile
     compile
     compile
+    test
 
     
+    
+      twttr-repo
+      Twttr Repository
+      http://maven.twttr.com
+      
+        true
+      
+      
+        false
+      
+    
     
     
       spark-1.4-staging
@@ -325,11 +341,6 @@
   
   
     
-      
-        ${jline.groupid}
-        jline
-        ${jline.version}
-      
       
         com.twitter
         chill_${scala.binary.version}
@@ -747,6 +758,10 @@
             asm
             asm
           
+          
+            org.codehaus.jackson
+            jackson-mapper-asl
+          
           
             org.ow2.asm
             asm
@@ -759,6 +774,10 @@
             commons-logging
             commons-logging
           
+          
+            org.mockito
+            mockito-all
+          
           
             org.mortbay.jetty
             servlet-api-2.5
@@ -1090,6 +1109,12 @@
         ${parquet.version}
         ${parquet.deps.scope}
       
+      
+        org.apache.parquet
+        parquet-avro
+        ${parquet.version}
+        ${parquet.test.deps.scope}
+      
       
         org.apache.flume
         flume-ng-core
@@ -1100,6 +1125,10 @@
             io.netty
             netty
           
+          
+            org.apache.flume
+            flume-ng-auth
+          
           
             org.apache.thrift
             libthrift
@@ -1295,6 +1324,7 @@
               false
               false
               true
+              true
             
           
           
@@ -1431,8 +1461,8 @@
         2.3
         
           false
-          
-          false
+          
+          ${create.dependency.reduced.pom}
           
             
               
@@ -1706,7 +1736,6 @@
       
         2.3.0
         0.9.3
-        3.1.1
       
     
 
@@ -1715,7 +1744,6 @@
       
         2.4.0
         0.9.3
-        3.1.1
       
     
 
@@ -1724,7 +1752,6 @@
       
         2.6.0
         0.9.3
-        3.1.1
         3.4.6
         2.6.0
       
@@ -1794,6 +1821,15 @@
         ${scala.version}
         org.scala-lang
       
+      
+        
+          
+            ${jline.groupid}
+            jline
+            ${jline.version}
+          
+        
+      
     
 
     
@@ -1812,10 +1848,28 @@
         scala-2.11
       
       
-        2.11.6
+        2.11.7
         2.11
-        2.12.1
-        jline
+      
+    
+
+    
+      
+      release
+      
+        
+        true
       
     
 
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6f86a505b3ae4..4e4e810ec36e3 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -58,24 +58,45 @@ object MimaExcludes {
               "org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.ml.classification.LogisticCostFun.this"),
             // SQL execution is considered private.
             excludePackage("org.apache.spark.sql.execution"),
-            // NanoTime and CatalystTimestampConverter is only used inside catalyst,
-            // not needed anymore
-            ProblemFilters.exclude[MissingClassProblem](
-              "org.apache.spark.sql.parquet.timestamp.NanoTime"),
-              ProblemFilters.exclude[MissingClassProblem](
-              "org.apache.spark.sql.parquet.timestamp.NanoTime$"),
-            ProblemFilters.exclude[MissingClassProblem](
-              "org.apache.spark.sql.parquet.CatalystTimestampConverter"),
+            // Parquet support is considered private.
+            excludePackage("org.apache.spark.sql.parquet"),
+            // local function inside a method
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
+          ) ++ Seq(
+            // SPARK-8479 Add numNonzeros and numActives to Matrix.
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.Matrix.numNonzeros"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.Matrix.numActives")
+          ) ++ Seq(
+            // SPARK-8914 Remove RDDApi
             ProblemFilters.exclude[MissingClassProblem](
-              "org.apache.spark.sql.parquet.CatalystTimestampConverter$"),
-            // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter
+            "org.apache.spark.sql.RDDApi")
+          ) ++ Seq(
+            // SPARK-8701 Add input metadata in the batch page.
             ProblemFilters.exclude[MissingClassProblem](
-              "org.apache.spark.sql.parquet.ParquetTypeInfo"),
+              "org.apache.spark.streaming.scheduler.InputInfo$"),
             ProblemFilters.exclude[MissingClassProblem](
-              "org.apache.spark.sql.parquet.ParquetTypeInfo$")
+              "org.apache.spark.streaming.scheduler.InputInfo")
+          ) ++ Seq(
+            // SPARK-6797 Support YARN modes for SparkR
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.r.PairwiseRRDD.this"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.r.RRDD.createRWorker"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.r.RRDD.this"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.r.StringRRDD.this"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.api.r.BaseRRDD.this")
           )
+
         case v if v.startsWith("1.4") =>
           Seq(
             MimaBuild.excludeSparkPackage("deploy"),
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index f5f1c9a1a247a..4291b0be2a616 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -45,8 +45,8 @@ object BuildCommons {
     sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
     "kinesis-asl").map(ProjectRef(buildLocation, _))
 
-  val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) =
-    Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly")
+  val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) =
+    Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly")
       .map(ProjectRef(buildLocation, _))
 
   val tools = ProjectRef(buildLocation, "tools")
@@ -69,6 +69,7 @@ object SparkBuild extends PomBuild {
     import scala.collection.mutable
     var isAlphaYarn = false
     var profiles: mutable.Seq[String] = mutable.Seq("sbt")
+    // scalastyle:off println
     if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) {
       println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.")
       profiles ++= Seq("spark-ganglia-lgpl")
@@ -88,6 +89,7 @@ object SparkBuild extends PomBuild {
       println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.")
       profiles ++= Seq("yarn")
     }
+    // scalastyle:on println
     profiles
   }
 
@@ -96,8 +98,10 @@ object SparkBuild extends PomBuild {
     case None => backwardCompatibility
     case Some(v) =>
       if (backwardCompatibility.nonEmpty)
+        // scalastyle:off println
         println("Note: We ignore environment variables, when use of profile is detected in " +
           "conjunction with environment variable.")
+        // scalastyle:on println
       v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
     }
 
@@ -161,7 +165,7 @@ object SparkBuild extends PomBuild {
   // Note ordering of these settings matter.
   /* Enable shared settings on all projects */
   (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools))
-    .foreach(enable(sharedSettings ++ ExludedDependencies.settings ++ Revolver.settings))
+    .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings))
 
   /* Enable tests settings for all projects except examples, assembly and tools */
   (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
@@ -206,7 +210,7 @@ object SparkBuild extends PomBuild {
     fork := true,
     outputStrategy in run := Some (StdoutOutput),
 
-    javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"),
+    javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m"),
 
     sparkShell := {
       (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value
@@ -246,7 +250,7 @@ object Flume {
   This excludes library dependencies in sbt, which are specified in maven but are
   not needed by sbt build.
   */
-object ExludedDependencies {
+object ExcludedDependencies {
   lazy val settings = Seq(
     libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }
   )
@@ -299,7 +303,7 @@ object SQL {
 object Hive {
 
   lazy val settings = Seq(
-    javaOptions += "-XX:MaxPermSize=1g",
+    javaOptions += "-XX:MaxPermSize=256m",
     // Specially disable assertions since some Hive tests fail them
     javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
     // Multiple queries rely on the TestHive singleton. See comments there for more details.
@@ -347,7 +351,7 @@ object Assembly {
         .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
     },
     jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
-      if (mName.contains("streaming-kafka-assembly")) {
+      if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) {
         // This must match the same name used in maven (see external/kafka-assembly/pom.xml)
         s"${mName}-${v}.jar"
       } else {
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 90b2fffbb9c7c..d7466729b8f36 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -291,6 +291,21 @@ def version(self):
         """
         return self._jsc.version()
 
+    @property
+    @ignore_unicode_prefix
+    def applicationId(self):
+        """
+        A unique identifier for the Spark application.
+        Its format depends on the scheduler implementation.
+        (e.g. 'local-1433865536131' in the case of a local Spark app,
+        or 'application_1433865536131_34483' when running on YARN)
+        >>> sc.applicationId  # doctest: +ELLIPSIS
+        u'local-...'
+        """
+        return self._jsc.sc().applicationId()
+
     @property
     def startTime(self):
         """Return the epoch time when the Spark Context was started."""
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3cee4ea6e3a35..90cd342a6cf7f 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -51,6 +51,8 @@ def launch_gateway():
         on_windows = platform.system() == "Windows"
         script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
         submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+        if os.environ.get("SPARK_TESTING"):
+            submit_args = "--conf spark.ui.enabled=false " + submit_args
         command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
 
         # Start a socket that will be used by PythonGatewayServer to communicate its port to us
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 7abbde8b260eb..89117e492846b 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -18,7 +18,8 @@
 from pyspark.ml.util import keyword_only
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
-from pyspark.ml.regression import RandomForestParams
+from pyspark.ml.regression import (
+    RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
 from pyspark.mllib.common import inherit_doc
 
 
@@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> td = si_model.transform(df)
     >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
     >>> model = dt.fit(td)
+    >>> model.numNodes
+    3
+    >>> model.depth
+    1
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
@@ -269,7 +274,8 @@ def getImpurity(self):
         return self.getOrDefault(self.impurity)
 
 
-class DecisionTreeClassificationModel(JavaModel):
+@inherit_doc
+class DecisionTreeClassificationModel(DecisionTreeModel):
     """
     Model fitted by DecisionTreeClassifier.
     """
@@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     It supports both binary and multiclass labels, as well as both continuous and categorical
     features.
 
+    >>> from numpy import allclose
     >>> from pyspark.mllib.linalg import Vectors
     >>> from pyspark.ml.feature import StringIndexer
     >>> df = sqlContext.createDataFrame([
@@ -294,6 +301,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> td = si_model.transform(df)
     >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
     >>> model = rf.fit(td)
+    >>> allclose(model.treeWeights, [1.0, 1.0])
+    True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
@@ -423,7 +432,7 @@ def getFeatureSubsetStrategy(self):
         return self.getOrDefault(self.featureSubsetStrategy)
 
 
-class RandomForestClassificationModel(JavaModel):
+class RandomForestClassificationModel(TreeEnsembleModels):
     """
     Model fitted by RandomForestClassifier.
     """
@@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     It supports binary labels, as well as both continuous and categorical features.
     Note: Multiclass labels are not currently supported.
 
+    >>> from numpy import allclose
     >>> from pyspark.mllib.linalg import Vectors
     >>> from pyspark.ml.feature import StringIndexer
     >>> df = sqlContext.createDataFrame([
@@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     >>> td = si_model.transform(df)
     >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
     >>> model = gbt.fit(td)
+    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
+    True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
@@ -558,7 +570,7 @@ def getStepSize(self):
         return self.getOrDefault(self.stepSize)
 
 
-class GBTClassificationModel(JavaModel):
+class GBTClassificationModel(TreeEnsembleModels):
     """
     Model fitted by GBTClassifier.
     """
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ddb33f427ac64..9bca7cc000aa5 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -21,7 +21,7 @@
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
 from pyspark.mllib.common import inherit_doc
 
-__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder',
+__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
            'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
            'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
            'Word2Vec', 'Word2VecModel']
@@ -265,6 +265,75 @@ class IDFModel(JavaModel):
     """
 
 
+@inherit_doc
+@ignore_unicode_prefix
+class NGram(JavaTransformer, HasInputCol, HasOutputCol):
+    """
+    A feature transformer that converts the input array of strings into an array of n-grams. Null
+    values in the input array are ignored.
+    It returns an array of n-grams where each n-gram is represented by a space-separated string of
+    words.
+    When the input is empty, an empty array is returned.
+    When the input array length is less than n (number of elements per n-gram), no n-grams are
+    returned.
+
+    >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
+    >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams")
+    >>> ngram.transform(df).head()
+    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e'])
+    >>> # Change n-gram length
+    >>> ngram.setParams(n=4).transform(df).head()
+    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
+    >>> # Temporarily modify output column.
+    >>> ngram.transform(df, {ngram.outputCol: "output"}).head()
+    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e'])
+    >>> ngram.transform(df).head()
+    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
+    >>> # Must use keyword arguments to specify params.
+    >>> ngram.setParams("text")
+    Traceback (most recent call last):
+        ...
+    TypeError: Method setParams forces keyword arguments.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)")
+
+    @keyword_only
+    def __init__(self, n=2, inputCol=None, outputCol=None):
+        """
+        __init__(self, n=2, inputCol=None, outputCol=None)
+        """
+        super(NGram, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
+        self.n = Param(self, "n", "number of elements per n-gram (>=1)")
+        self._setDefault(n=2)
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, n=2, inputCol=None, outputCol=None):
+        """
+        setParams(self, n=2, inputCol=None, outputCol=None)
+        Sets params for this NGram.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setN(self, value):
+        """
+        Sets the value of :py:attr:`n`.
+        """
+        self._paramMap[self.n] = value
+        return self
+
+    def getN(self):
+        """
+        Gets the value of n or its default value.
+        """
+        return self.getOrDefault(self.n)
+
+
 @inherit_doc
 class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
     """
@@ -558,6 +627,10 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
     >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
     >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled")
     >>> model = standardScaler.fit(df)
+    >>> model.mean
+    DenseVector([1.0])
+    >>> model.std
+    DenseVector([1.4142])
     >>> model.transform(df).collect()[1].scaled
     DenseVector([1.4142])
     """
@@ -623,6 +696,20 @@ class StandardScalerModel(JavaModel):
     Model fitted by StandardScaler.
     """
 
+    @property
+    def std(self):
+        """
+        Standard deviation of the StandardScalerModel.
+        """
+        return self._call_java("std")
+
+    @property
+    def mean(self):
+        """
+        Mean of the StandardScalerModel.
+        """
+        return self._call_java("mean")
+
 
 @inherit_doc
 class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
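
One behavior stated in the `NGram` docstring but not covered by its doctests: an input with fewer than `n` tokens yields no n-grams. A hedged usage sketch (assumes a live `sqlContext`, as in the doctests above):

```python
from pyspark.sql import Row
from pyspark.ml.feature import NGram

# Three tokens with n=4: too short to form a single 4-gram, so the output is empty.
df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c"])])
ngram = NGram(n=4, inputCol="inputTokens", outputCol="nGrams")
print(ngram.transform(df).head().nGrams)  # expected: []
```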
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a563024b2cdcb..9889f56cac9e4 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -42,7 +42,7 @@ def _fit(self, dataset):
         """
         raise NotImplementedError()
 
-    def fit(self, dataset, params={}):
+    def fit(self, dataset, params=None):
         """
         Fits a model to the input dataset with optional parameters.
 
@@ -54,6 +54,8 @@ def fit(self, dataset, params={}):
                        list of models.
         :returns: fitted model(s)
         """
+        if params is None:
+            params = dict()
         if isinstance(params, (list, tuple)):
             return [self.fit(dataset, paramMap) for paramMap in params]
         elif isinstance(params, dict):
@@ -86,7 +88,7 @@ def _transform(self, dataset):
         """
         raise NotImplementedError()
 
-    def transform(self, dataset, params={}):
+    def transform(self, dataset, params=None):
         """
         Transforms the input dataset with optional parameters.
 
@@ -96,6 +98,8 @@ def transform(self, dataset, params={}):
                        params.
         :returns: transformed dataset
         """
+        if params is None:
+            params = dict()
         if isinstance(params, dict):
             if params:
                 return self.copy(params,)._transform(dataset)
@@ -135,10 +139,12 @@ class Pipeline(Estimator):
     """
 
     @keyword_only
-    def __init__(self, stages=[]):
+    def __init__(self, stages=None):
         """
         __init__(self, stages=[])
         """
+        if stages is None:
+            stages = []
         super(Pipeline, self).__init__()
         #: Param for pipeline stages.
         self.stages = Param(self, "stages", "pipeline stages")
@@ -162,11 +168,13 @@ def getStages(self):
             return self._paramMap[self.stages]
 
     @keyword_only
-    def setParams(self, stages=[]):
+    def setParams(self, stages=None):
         """
         setParams(self, stages=[])
         Sets params for Pipeline.
         """
+        if stages is None:
+            stages = []
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
@@ -195,7 +203,9 @@ def _fit(self, dataset):
                 transformers.append(stage)
         return PipelineModel(transformers)
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
+        if extra is None:
+            extra = dict()
         that = Params.copy(self, extra)
         stages = [stage.copy(extra) for stage in that.getStages()]
         return that.setStages(stages)
@@ -216,6 +226,8 @@ def _transform(self, dataset):
             dataset = t.transform(dataset)
         return dataset
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
+        if extra is None:
+            extra = dict()
         stages = [stage.copy(extra) for stage in self.stages]
         return PipelineModel(stages)
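
The `params={}`, `stages=[]`, and `extra={}` defaults replaced in this file (and in `wrapper.py` further down) hit Python's mutable-default-argument pitfall: the default object is created once at definition time and shared by every call. A minimal illustration of the hazard and of the `None` sentinel pattern the patch adopts (hypothetical helper, not Spark code):

```python
def append_bad(item, acc=[]):
    # The default list is created once and shared across every call.
    acc.append(item)
    return acc

def append_good(item, acc=None):
    # The None-sentinel pattern adopted by the patch: a fresh list per call.
    if acc is None:
        acc = []
    acc.append(item)
    return acc

print(append_bad(1))   # [1]
print(append_bad(2))   # [1, 2]  (the shared default still holds the 1 from the first call)
print(append_good(1))  # [1]
print(append_good(2))  # [2]
```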
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b139e27372d80..44f60a769566d 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -172,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> dt = DecisionTreeRegressor(maxDepth=2)
     >>> model = dt.fit(df)
+    >>> model.depth
+    1
+    >>> model.numNodes
+    3
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
@@ -239,7 +243,37 @@ def getImpurity(self):
         return self.getOrDefault(self.impurity)
 
 
-class DecisionTreeRegressionModel(JavaModel):
+@inherit_doc
+class DecisionTreeModel(JavaModel):
+
+    @property
+    def numNodes(self):
+        """Return number of nodes of the decision tree."""
+        return self._call_java("numNodes")
+
+    @property
+    def depth(self):
+        """Return depth of the decision tree."""
+        return self._call_java("depth")
+
+    def __repr__(self):
+        return self._call_java("toString")
+
+
+@inherit_doc
+class TreeEnsembleModels(JavaModel):
+
+    @property
+    def treeWeights(self):
+        """Return the weights for each tree"""
+        return list(self._call_java("javaTreeWeights"))
+
+    def __repr__(self):
+        return self._call_java("toString")
+
+
+@inherit_doc
+class DecisionTreeRegressionModel(DecisionTreeModel):
     """
     Model fitted by DecisionTreeRegressor.
     """
@@ -253,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     learning algorithm for regression.
     It supports both continuous and categorical features.
 
+    >>> from numpy import allclose
     >>> from pyspark.mllib.linalg import Vectors
     >>> df = sqlContext.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
     >>> model = rf.fit(df)
+    >>> allclose(model.treeWeights, [1.0, 1.0])
+    True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
@@ -389,7 +426,7 @@ def getFeatureSubsetStrategy(self):
         return self.getOrDefault(self.featureSubsetStrategy)
 
 
-class RandomForestRegressionModel(JavaModel):
+class RandomForestRegressionModel(TreeEnsembleModels):
     """
     Model fitted by RandomForestRegressor.
     """
@@ -403,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
     learning algorithm for regression.
     It supports both continuous and categorical features.
 
+    >>> from numpy import allclose
     >>> from pyspark.mllib.linalg import Vectors
     >>> df = sqlContext.createDataFrame([
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
     >>> model = gbt.fit(df)
+    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
+    True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
@@ -518,7 +558,7 @@ def getStepSize(self):
         return self.getOrDefault(self.stepSize)
 
 
-class GBTRegressionModel(JavaModel):
+class GBTRegressionModel(TreeEnsembleModels):
     """
     Model fitted by GBTRegressor.
     """
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6adbf166f34a8..c151d21fd661a 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -252,6 +252,17 @@ def test_idf(self):
         output = idf0m.transform(dataset)
         self.assertIsNotNone(output.head().idf)
 
+    def test_ngram(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([
+            ([["a", "b", "c", "d", "e"]])], ["input"])
+        ngram0 = NGram(n=4, inputCol="input", outputCol="output")
+        self.assertEqual(ngram0.getN(), 4)
+        self.assertEqual(ngram0.getInputCol(), "input")
+        self.assertEqual(ngram0.getOutputCol(), "output")
+        transformedDF = ngram0.transform(dataset)
+        self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7b0893e2cdadc..253705bde913e 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -166,7 +166,7 @@ def __init__(self, java_model):
         self._java_obj = java_model
         self.uid = java_model.uid()
 
-    def copy(self, extra={}):
+    def copy(self, extra=None):
         """
         Creates a copy of this instance with the same uid and some
         extra params. This implementation first calls Params.copy and
@@ -175,6 +175,8 @@ def copy(self, extra={}):
         :param extra: Extra parameters to copy to the new instance
         :return: Copy of this instance
         """
+        if extra is None:
+            extra = dict()
         that = super(JavaModel, self).copy(extra)
         that._java_obj = self._java_obj.copy(self._empty_java_param_map())
         that._transfer_params_to_java()
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 735d45ba03d27..8f27c446a66e8 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -24,7 +24,9 @@
 from pyspark.streaming import DStream
 from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
 from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
-from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
+from pyspark.mllib.regression import (
+    LabeledPoint, LinearModel, _regression_train_wrapper,
+    StreamingLinearAlgorithm)
 from pyspark.mllib.util import Saveable, Loader, inherit_doc
 
 
@@ -585,55 +587,13 @@ def train(cls, data, lambda_=1.0):
         return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
 
 
-class StreamingLinearAlgorithm(object):
-    """
-    Base class that has to be inherited by any StreamingLinearAlgorithm.
-
-    Prevents reimplementation of methods predictOn and predictOnValues.
-    """
-    def __init__(self, model):
-        self._model = model
-
-    def latestModel(self):
-        """
-        Returns the latest model.
-        """
-        return self._model
-
-    def _validate(self, dstream):
-        if not isinstance(dstream, DStream):
-            raise TypeError(
-                "dstream should be a DStream object, got %s" % type(dstream))
-        if not self._model:
-            raise ValueError(
-                "Model must be intialized using setInitialWeights")
-
-    def predictOn(self, dstream):
-        """
-        Make predictions on a dstream.
-
-        :return: Transformed dstream object.
-        """
-        self._validate(dstream)
-        return dstream.map(lambda x: self._model.predict(x))
-
-    def predictOnValues(self, dstream):
-        """
-        Make predictions on a keyed dstream.
-
-        :return: Transformed dstream object.
-        """
-        self._validate(dstream)
-        return dstream.mapValues(lambda x: self._model.predict(x))
-
-
 @inherit_doc
 class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm):
     """
-    Run LogisticRegression with SGD on a stream of data.
+    Run LogisticRegression with SGD on a batch of data.
 
     The weights obtained at the end of training a stream are used as initial
-    weights for the next stream.
+    weights for the next batch.
 
     :param stepSize: Step size for each iteration of gradient descent.
     :param numIterations: Number of iterations run for each batch of data.
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index e3c8a24c4a751..ed4d78a2c6788 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -282,22 +282,30 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
 
     Model produced by [[PowerIterationClustering]].
 
-    >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0),
-    ...     (0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)]
+    >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0),
+    ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0),
+    ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0),
+    ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)]
     >>> rdd = sc.parallelize(data, 2)
     >>> model = PowerIterationClustering.train(rdd, 2, 100)
     >>> model.k
     2
-    >>> sorted(model.assignments().collect())
-    [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ...
+    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
+    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
+    True
+    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
+    True
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> model.save(sc, path)
     >>> sameModel = PowerIterationClusteringModel.load(sc, path)
     >>> sameModel.k
     2
-    >>> sorted(sameModel.assignments().collect())
-    [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ...
+    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
+    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
+    True
+    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
+    True
     >>> from shutil import rmtree
     >>> try:
     ...     rmtree(path)
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index c5cf3a4e7ff22..f21403707e12a 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -152,6 +152,10 @@ class MulticlassMetrics(JavaModelWrapper):
     >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
     ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
     >>> metrics = MulticlassMetrics(predictionAndLabels)
+    >>> metrics.confusionMatrix().toArray()
+    array([[ 2.,  1.,  1.],
+           [ 1.,  3.,  0.],
+           [ 0.,  0.,  1.]])
     >>> metrics.falsePositiveRate(0.0)
     0.2...
     >>> metrics.precision(1.0)
@@ -186,6 +190,13 @@ def __init__(self, predictionAndLabels):
         java_model = java_class(df._jdf)
         super(MulticlassMetrics, self).__init__(java_model)
 
+    def confusionMatrix(self):
+        """
+        Returns the confusion matrix: predicted classes are in columns,
+        ordered by class label ascending, as in "labels".
+        """
+        return self.call("confusionMatrix")
+
     def truePositiveRate(self, label):
         """
         Returns true positive rate for a given label (category).
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index f00bb93b7bf40..f921e3ad1a314 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -36,6 +36,7 @@
 from pyspark.mllib.linalg import (
     Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
            'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
@@ -111,6 +112,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
     """
 
     def transform(self, vector):
+        """
+        Applies transformation on a vector or an RDD[Vector].
+
+        Note: In Python, transform cannot currently be used within
+              an RDD transformation or action.
+              Call transform directly on the RDD instead.
+
+        :param vector: Vector or RDD of Vector to be transformed.
+        """
         if isinstance(vector, RDD):
             vector = vector.map(_convert_to_vector)
         else:
@@ -191,7 +201,7 @@ def fit(self, dataset):
         Computes the mean and variance and stores as a model to be used
         for later scaling.
 
-        :param data: The data used to compute the mean and variance
+        :param dataset: The data used to compute the mean and variance
                      to build the transformation model.
         :return: a StandardScalarModel
         """
@@ -346,10 +356,6 @@ def transform(self, x):
                   vector
         :return: an RDD of TF-IDF vectors or a TF-IDF vector
         """
-        if isinstance(x, RDD):
-            return JavaVectorTransformer.transform(self, x)
-
-        x = _convert_to_vector(x)
         return JavaVectorTransformer.transform(self, x)
 
     def idf(self):
@@ -411,7 +417,7 @@ def fit(self, dataset):
         return IDFModel(jmodel)
 
 
-class Word2VecModel(JavaVectorTransformer):
+class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
     """
     class for Word2Vec model
     """
@@ -450,6 +456,12 @@ def getVectors(self):
         """
         return self.call("getVectors")
 
+    @classmethod
+    def load(cls, sc, path):
+        jmodel = sc._jvm.org.apache.spark.mllib.feature \
+            .Word2VecModel.load(sc._jsc.sc(), path)
+        return Word2VecModel(jmodel)
+
 
 @ignore_unicode_prefix
 class Word2Vec(object):
@@ -483,6 +495,18 @@ class Word2Vec(object):
     >>> syms = model.findSynonyms(vec, 2)
     >>> [s[0] for s in syms]
     [u'b', u'c']
+
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = Word2VecModel.load(sc, path)
+    >>> model.transform("a") == sameModel.transform("a")
+    True
+    >>> from shutil import rmtree
+    >>> try:
+    ...     rmtree(path)
+    ... except OSError:
+    ...     pass
     """
     def __init__(self):
         """
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index b7f00d60069e6..bdc4a132b1b18 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper):
     >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
     >>> rdd = sc.parallelize(data, 2)
     >>> model = FPGrowth.train(rdd, 0.6, 2)
-    >>> sorted(model.freqItemsets().collect(), key=lambda x: x.items)
-    [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ...
+    >>> sorted(model.freqItemsets().collect())
+    [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
     """
 
     def freqItemsets(self):
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index e96c5ef87df86..040886f71775b 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -31,6 +31,7 @@
     xrange = range
     import copyreg as copy_reg
 else:
+    from itertools import izip as zip
     import copy_reg
 
 import numpy as np
@@ -116,6 +117,10 @@ def _format_float(f, digits=4):
     return s
 
 
+def _format_float_list(l):
+    return [_format_float(x) for x in l]
+
+
 class VectorUDT(UserDefinedType):
     """
     SQL user-defined type (UDT) for Vector.
@@ -440,8 +445,10 @@ def __init__(self, size, *args):
         values (sorted by index).
 
         :param size: Size of the vector.
-        :param args: Non-zero entries, as a dictionary, list of tupes,
-               or two sorted lists containing indices and values.
+        :param args: Active entries, as a dictionary {index: value, ...},
+          a list of tuples [(index, value), ...], or a list of strictly
+          increasing indices and a list of corresponding values [index, ...],
+          [value, ...]. Inactive entries are treated as zeros.
 
         >>> SparseVector(4, {1: 1.0, 3: 5.5})
         SparseVector(4, {1: 1.0, 3: 5.5})
@@ -451,6 +458,7 @@ def __init__(self, size, *args):
         SparseVector(4, {1: 1.0, 3: 5.5})
         """
         self.size = int(size)
+        """ Size of the vector. """
         assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
         if len(args) == 1:
             pairs = args[0]
@@ -458,7 +466,9 @@ def __init__(self, size, *args):
                 pairs = pairs.items()
             pairs = sorted(pairs)
             self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
+            """ A list of indices corresponding to active entries. """
             self.values = np.array([p[1] for p in pairs], dtype=np.float64)
+            """ A list of values corresponding to active entries. """
         else:
             if isinstance(args[0], bytes):
                 assert isinstance(args[1], bytes), "values should be string too"
@@ -577,34 +587,27 @@ def dot(self, other):
             ...
         AssertionError: dimension mismatch
         """
-        if type(other) == np.ndarray:
-            if other.ndim == 2:
-                results = [self.dot(other[:, i]) for i in xrange(other.shape[1])]
-                return np.array(results)
-            elif other.ndim > 2:
+
+        if isinstance(other, np.ndarray):
+            if other.ndim not in [2, 1]:
                 raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim)
+            assert len(self) == other.shape[0], "dimension mismatch"
+            return np.dot(self.values, other[self.indices])
 
         assert len(self) == _vector_size(other), "dimension mismatch"
 
-        if type(other) in (np.ndarray, array.array, DenseVector):
-            result = 0.0
-            for i in xrange(len(self.indices)):
-                result += self.values[i] * other[self.indices[i]]
-            return result
+        if isinstance(other, DenseVector):
+            return np.dot(other.array[self.indices], self.values)
 
-        elif type(other) is SparseVector:
-            result = 0.0
-            i, j = 0, 0
-            while i < len(self.indices) and j < len(other.indices):
-                if self.indices[i] == other.indices[j]:
-                    result += self.values[i] * other.values[j]
-                    i += 1
-                    j += 1
-                elif self.indices[i] < other.indices[j]:
-                    i += 1
-                else:
-                    j += 1
-            return result
+        elif isinstance(other, SparseVector):
+            # Find out common indices.
+            self_cmind = np.in1d(self.indices, other.indices, assume_unique=True)
+            self_values = self.values[self_cmind]
+            if self_values.size == 0:
+                return 0.0
+            else:
+                other_cmind = np.in1d(other.indices, self.indices, assume_unique=True)
+                return np.dot(self_values, other.values[other_cmind])
 
         else:
             return self.dot(_convert_to_vector(other))
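A note on the sparse-sparse branch above: the two-pointer walk is replaced by vectorised NumPy, where np.in1d picks out the indices each vector shares with the other and a single np.dot finishes the job. A minimal standalone sketch of the same idea, using hypothetical inputs:

    import numpy as np

    # Hypothetical sparse vectors as (sorted indices, values) pairs.
    a_ind, a_val = np.array([0, 3, 7]), np.array([1.0, 2.0, 3.0])
    b_ind, b_val = np.array([3, 5, 7]), np.array([4.0, 5.0, 6.0])

    # Mask of entries whose index also appears in the other vector.
    a_common = np.in1d(a_ind, b_ind, assume_unique=True)
    b_common = np.in1d(b_ind, a_ind, assume_unique=True)

    # Because both index arrays are sorted, the masked values line up pairwise.
    print(np.dot(a_val[a_common], b_val[b_common]))  # 2*4 + 3*6 = 26.0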
@@ -635,22 +638,23 @@ def squared_distance(self, other):
         AssertionError: dimension mismatch
         """
         assert len(self) == _vector_size(other), "dimension mismatch"
-        if type(other) in (list, array.array, DenseVector, np.array, np.ndarray):
-            if type(other) is np.array and other.ndim != 1:
+
+        if isinstance(other, np.ndarray) or isinstance(other, DenseVector):
+            if isinstance(other, np.ndarray) and other.ndim != 1:
                 raise Exception("Cannot call squared_distance with %d-dimensional array" %
                                 other.ndim)
-            result = 0.0
-            j = 0   # index into our own array
-            for i in xrange(len(other)):
-                if j < len(self.indices) and self.indices[j] == i:
-                    diff = self.values[j] - other[i]
-                    result += diff * diff
-                    j += 1
-                else:
-                    result += other[i] * other[i]
+            if isinstance(other, DenseVector):
+                other = other.array
+            sparse_ind = np.zeros(other.size, dtype=bool)
+            sparse_ind[self.indices] = True
+            dist = other[sparse_ind] - self.values
+            result = np.dot(dist, dist)
+
+            other_ind = other[~sparse_ind]
+            result += np.dot(other_ind, other_ind)
             return result
 
-        elif type(other) is SparseVector:
+        elif isinstance(other, SparseVector):
             result = 0.0
             i, j = 0, 0
             while i < len(self.indices) and j < len(other.indices):
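For the dense/ndarray case, the rewrite above builds a boolean mask of the active positions, so the squared distance splits into two vectorised dot products: one over positions the sparse vector covers and one over the rest. A short sketch with the same numbers as the existing `_squared_distance(sv, dv)` check in python/pyspark/mllib/tests.py (SparseVector(4, {1: 1, 3: 2}) against [1., 2., 3., 4.]):

    import numpy as np

    indices, values = np.array([1, 3]), np.array([1.0, 2.0])
    other = np.array([1.0, 2.0, 3.0, 4.0])

    mask = np.zeros(other.size, dtype=bool)
    mask[indices] = True

    diff = other[mask] - values        # positions where the sparse vector is active
    result = np.dot(diff, diff)        # (2-1)^2 + (4-2)^2 = 5
    rest = other[~mask]                # positions where the sparse vector is zero
    result += np.dot(rest, rest)       # + 1^2 + 3^2 = 10
    print(result)                      # 15.0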
@@ -876,6 +880,50 @@ def __reduce__(self):
             self.numRows, self.numCols, self.values.tostring(),
             int(self.isTransposed))
 
+    def __str__(self):
+        """
+        Pretty printing of a DenseMatrix
+
+        >>> dm = DenseMatrix(2, 2, range(4))
+        >>> print(dm)
+        DenseMatrix([[ 0.,  2.],
+                     [ 1.,  3.]])
+        >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True)
+        >>> print(dm)
+        DenseMatrix([[ 0.,  1.],
+                     [ 2.,  3.]])
+        """
+        # Inspired by __repr__ in scipy matrices.
+        array_lines = repr(self.toArray()).splitlines()
+
+        # We need to indent by six spaces, which is the difference in
+        # length between "DenseMatrix" and "array".
+        x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]])
+        return array_lines[0].replace("array", "DenseMatrix") + "\n" + x
+
+    def __repr__(self):
+        """
+        Representation of a DenseMatrix
+
+        >>> dm = DenseMatrix(2, 2, range(4))
+        >>> dm
+        DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False)
+        """
+        # If the number of values is less than 17, return them as they are.
+        # Otherwise return the first eight and the last eight values.
+        if len(self.values) < 17:
+            entries = _format_float_list(self.values)
+        else:
+            entries = (
+                _format_float_list(self.values[:8]) +
+                ["..."] +
+                _format_float_list(self.values[-8:])
+            )
+
+        entries = ", ".join(entries)
+        return "DenseMatrix({0}, {1}, [{2}], {3})".format(
+            self.numRows, self.numCols, entries, self.isTransposed)
+
     def toArray(self):
         """
         Return an numpy.ndarray
@@ -952,6 +1000,84 @@ def __init__(self, numRows, numCols, colPtrs, rowIndices, values,
             raise ValueError("Expected rowIndices of length %d, got %d."
                              % (self.rowIndices.size, self.values.size))
 
+    def __str__(self):
+        """
+        Pretty printing of a SparseMatrix
+
+        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+        >>> print(sm1)
+        2 X 2 CSCMatrix
+        (0,0) 2.0
+        (1,0) 3.0
+        (1,1) 4.0
+        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+        >>> print(sm1)
+        2 X 2 CSRMatrix
+        (0,0) 2.0
+        (0,1) 3.0
+        (1,1) 4.0
+        """
+        spstr = "{0} X {1} ".format(self.numRows, self.numCols)
+        if self.isTransposed:
+            spstr += "CSRMatrix\n"
+        else:
+            spstr += "CSCMatrix\n"
+
+        cur_col = 0
+        smlist = []
+
+        # Display first 16 values.
+        if len(self.values) <= 16:
+            zipindval = zip(self.rowIndices, self.values)
+        else:
+            zipindval = zip(self.rowIndices[:16], self.values[:16])
+        for i, (rowInd, value) in enumerate(zipindval):
+            if self.colPtrs[cur_col + 1] <= i:
+                cur_col += 1
+            if self.isTransposed:
+                smlist.append('({0},{1}) {2}'.format(
+                    cur_col, rowInd, _format_float(value)))
+            else:
+                smlist.append('({0},{1}) {2}'.format(
+                    rowInd, cur_col, _format_float(value)))
+        spstr += "\n".join(smlist)
+
+        if len(self.values) > 16:
+            spstr += "\n.." * 2
+        return spstr
+
+    def __repr__(self):
+        """
+        Representation of a SparseMatrix
+
+        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+        >>> sm1
+        SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False)
+        """
+        rowIndices = list(self.rowIndices)
+        colPtrs = list(self.colPtrs)
+
+        if len(self.values) <= 16:
+            values = _format_float_list(self.values)
+
+        else:
+            values = (
+                _format_float_list(self.values[:8]) +
+                ["..."] +
+                _format_float_list(self.values[-8:])
+            )
+            rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:]
+
+        if len(self.colPtrs) > 16:
+            colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:]
+
+        values = ", ".join(values)
+        rowIndices = ", ".join([str(ind) for ind in rowIndices])
+        colPtrs = ", ".join([str(ptr) for ptr in colPtrs])
+        return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
+            self.numRows, self.numCols, colPtrs, rowIndices,
+            values, self.isTransposed)
+
     def __reduce__(self):
         return SparseMatrix, (
             self.numRows, self.numCols, self.colPtrs.tostring(),
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 5ddbbee4babdd..8e90adee5f4c2 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -19,6 +19,7 @@
 from numpy import array
 
 from pyspark import RDD
+from pyspark.streaming.dstream import DStream
 from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
 from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector
 from pyspark.mllib.util import Saveable, Loader
@@ -570,6 +571,95 @@ def train(cls, data, isotonic=True):
         return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)
 
 
+class StreamingLinearAlgorithm(object):
+    """
+    Base class that has to be inherited by any StreamingLinearAlgorithm.
+
+    Prevents reimplementation of methods predictOn and predictOnValues.
+    """
+    def __init__(self, model):
+        self._model = model
+
+    def latestModel(self):
+        """
+        Returns the latest model.
+        """
+        return self._model
+
+    def _validate(self, dstream):
+        if not isinstance(dstream, DStream):
+            raise TypeError(
+                "dstream should be a DStream object, got %s" % type(dstream))
+        if not self._model:
+            raise ValueError(
+                "Model must be intialized using setInitialWeights")
+
+    def predictOn(self, dstream):
+        """
+        Make predictions on a dstream.
+
+        :return: Transformed dstream object.
+        """
+        self._validate(dstream)
+        return dstream.map(lambda x: self._model.predict(x))
+
+    def predictOnValues(self, dstream):
+        """
+        Make predictions on a keyed dstream.
+
+        :return: Transformed dstream object.
+        """
+        self._validate(dstream)
+        return dstream.mapValues(lambda x: self._model.predict(x))
+
+
+@inherit_doc
+class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm):
+    """
+    Run LinearRegression with SGD on a batch of data.
+
+    The problem minimized is (1 / n_samples) * (y - weights'X)**2.
+    After training on a batch of data, the weights obtained at the end of
+    training are used as initial weights for the next batch.
+
+    :param stepSize: Step size for each iteration of gradient descent.
+    :param numIterations: Total number of iterations run.
+    :param miniBatchFraction: Fraction of data on which SGD is run for each
+                              iteration.
+    """
+    def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0):
+        self.stepSize = stepSize
+        self.numIterations = numIterations
+        self.miniBatchFraction = miniBatchFraction
+        self._model = None
+        super(StreamingLinearRegressionWithSGD, self).__init__(
+            model=self._model)
+
+    def setInitialWeights(self, initialWeights):
+        """
+        Set the initial value of weights.
+
+        This must be set before running trainOn and predictOn.
+        """
+        initialWeights = _convert_to_vector(initialWeights)
+        self._model = LinearRegressionModel(initialWeights, 0)
+        return self
+
+    def trainOn(self, dstream):
+        """Train the model on the incoming dstream."""
+        self._validate(dstream)
+
+        def update(rdd):
+            # LinearRegressionWithSGD.train raises an error for an empty RDD.
+            if not rdd.isEmpty():
+                self._model = LinearRegressionWithSGD.train(
+                    rdd, self.numIterations, self.stepSize,
+                    self.miniBatchFraction, self._model.weights,
+                    self._model.intercept)
+
+        dstream.foreachRDD(update)
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
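A minimal usage sketch for the new streaming regressor, mirroring the tests added to python/pyspark/mllib/tests.py further down in this patch; `ssc`, `trainingStream` (a DStream of LabeledPoint) and `testStream` (a DStream of (label, features) pairs) are assumed to already exist:

    from pyspark.mllib.regression import StreamingLinearRegressionWithSGD

    slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
    slr.setInitialWeights([0.0, 0.0])      # required before trainOn/predictOn

    slr.trainOn(trainingStream)            # weights are carried over batch to batch
    predictions = slr.predictOnValues(testStream)
    predictions.pprint()

    ssc.start()
    # ... after a few batches ...
    print(slr.latestModel().weights)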
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index cd80c3e07a4f7..f2eab5b18f077 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -27,8 +27,9 @@
 from shutil import rmtree
 
 from numpy import (
-    array, array_equal, zeros, inf, random, exp, dot, all, mean)
+    array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
 from numpy import sum as array_sum
+
 from py4j.protocol import Py4JJavaError
 
 if sys.version_info[:2] <= (2, 6):
@@ -45,17 +46,19 @@
 from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
     DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
-from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
+from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
 from pyspark.mllib.feature import Word2Vec
 from pyspark.mllib.feature import IDF
 from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
 from pyspark.mllib.util import LinearDataGenerator
+from pyspark.mllib.util import MLUtils
 from pyspark.serializers import PickleSerializer
 from pyspark.streaming import StreamingContext
 from pyspark.sql import SQLContext
 
 _have_scipy = False
 try:
@@ -126,17 +129,22 @@ def test_dot(self):
                      [1., 2., 3., 4.],
                      [1., 2., 3., 4.],
                      [1., 2., 3., 4.]])
+        arr = pyarray.array('d', [0, 1, 2, 3])
         self.assertEquals(10.0, sv.dot(dv))
         self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
         self.assertEquals(30.0, dv.dot(dv))
         self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
         self.assertEquals(30.0, lst.dot(dv))
         self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
+        self.assertEquals(7.0, sv.dot(arr))
 
     def test_squared_distance(self):
         sv = SparseVector(4, {1: 1, 3: 2})
         dv = DenseVector(array([1., 2., 3., 4.]))
         lst = DenseVector([4, 3, 2, 1])
+        lst1 = [4, 3, 2, 1]
+        arr = pyarray.array('d', [0, 2, 1, 3])
+        narr = array([0, 2, 1, 3])
         self.assertEquals(15.0, _squared_distance(sv, dv))
         self.assertEquals(25.0, _squared_distance(sv, lst))
         self.assertEquals(20.0, _squared_distance(dv, lst))
@@ -146,6 +154,9 @@ def test_squared_distance(self):
         self.assertEquals(0.0, _squared_distance(sv, sv))
         self.assertEquals(0.0, _squared_distance(dv, dv))
         self.assertEquals(0.0, _squared_distance(lst, lst))
+        self.assertEquals(25.0, _squared_distance(sv, lst1))
+        self.assertEquals(3.0, _squared_distance(sv, arr))
+        self.assertEquals(3.0, _squared_distance(sv, narr))
 
     def test_conversion(self):
         # numpy arrays should be automatically upcast to float64
@@ -178,6 +189,53 @@ def test_matrix_indexing(self):
             for j in range(2):
                 self.assertEquals(mat[i, j], expected[i][j])
 
+    def test_repr_dense_matrix(self):
+        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+        self.assertTrue(
+            repr(mat),
+            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
+
+        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
+        self.assertTrue(
+            repr(mat),
+            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], True)')
+
+        mat = DenseMatrix(6, 3, zeros(18))
+        self.assertTrue(
+            repr(mat),
+            'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
+                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
+
+    def test_repr_sparse_matrix(self):
+        sm1t = SparseMatrix(
+            3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
+            isTransposed=True)
+        self.assertTrue(
+            repr(sm1t),
+            'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
+
+        indices = tile(arange(6), 3)
+        values = ones(18)
+        sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
+        self.assertTrue(
+            repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
+                [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
+                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
+                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
+
+        self.assertTrue(
+            str(sm),
+            "6 X 3 CSCMatrix\n\
+            (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
+            (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
+            (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
+
+        sm = SparseMatrix(1, 18, zeros(19), [], [])
+        self.assertTrue(
+            repr(sm),
+            'SparseMatrix(1, 18, \
+                [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
+
     def test_sparse_matrix(self):
         # Test sparse matrix creation.
         sm1 = SparseMatrix(
@@ -187,6 +245,9 @@ def test_sparse_matrix(self):
         self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
         self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2])
         self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
+        self.assertTrue(
+            repr(sm1),
+            'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
 
         # Test indexing
         expected = [
@@ -1170,6 +1231,166 @@ def collect_errors(rdd):
         self.assertTrue(errors[1] - errors[-1] > 0.3)
 
 
+class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
+
+    def assertArrayAlmostEqual(self, array1, array2, dec):
+        for i, j in zip(array1, array2):
+            self.assertAlmostEqual(i, j, dec)
+
+    def test_parameter_accuracy(self):
+        """Test that coefs are predicted accurately by fitting on toy data."""
+
+        # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients
+        # (10, 10)
+        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+        slr.setInitialWeights([0.0, 0.0])
+        xMean = [0.0, 0.0]
+        xVariance = [1.0 / 3.0, 1.0 / 3.0]
+
+        # Create ten batches with 100 sample points in each.
+        batches = []
+        for i in range(10):
+            batch = LinearDataGenerator.generateLinearInput(
+                0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1)
+            batches.append(sc.parallelize(batch))
+
+        input_stream = self.ssc.queueStream(batches)
+        t = time()
+        slr.trainOn(input_stream)
+        self.ssc.start()
+        self._ssc_wait(t, 10, 0.01)
+        self.assertArrayAlmostEqual(
+            slr.latestModel().weights.array, [10., 10.], 1)
+        self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1)
+
+    def test_parameter_convergence(self):
+        """Test that the model parameters improve with streaming data."""
+        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+        slr.setInitialWeights([0.0])
+
+        # Create ten batches with 100 sample points in each.
+        batches = []
+        for i in range(10):
+            batch = LinearDataGenerator.generateLinearInput(
+                0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
+            batches.append(sc.parallelize(batch))
+
+        model_weights = []
+        input_stream = self.ssc.queueStream(batches)
+        input_stream.foreachRDD(
+            lambda x: model_weights.append(slr.latestModel().weights[0]))
+        t = time()
+        slr.trainOn(input_stream)
+        self.ssc.start()
+        self._ssc_wait(t, 10, 0.01)
+
+        model_weights = array(model_weights)
+        diff = model_weights[1:] - model_weights[:-1]
+        self.assertTrue(all(diff >= -0.1))
+
+    def test_prediction(self):
+        """Test prediction on a model with weights already set."""
+        # Create a model with initial weights equal to the true coefficients
+        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+        slr.setInitialWeights([10.0, 10.0])
+
+        # Create ten batches with 100 sample points in each.
+        batches = []
+        for i in range(10):
+            batch = LinearDataGenerator.generateLinearInput(
+                0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0],
+                100, 42 + i, 0.1)
+            batches.append(
+                sc.parallelize(batch).map(lambda lp: (lp.label, lp.features)))
+
+        input_stream = self.ssc.queueStream(batches)
+        t = time()
+        output_stream = slr.predictOnValues(input_stream)
+        samples = []
+        output_stream.foreachRDD(lambda x: samples.append(x.collect()))
+
+        self.ssc.start()
+        self._ssc_wait(t, 5, 0.01)
+
+        # Test that mean absolute error on each batch is less than 0.1
+        for batch in samples:
+            true, predicted = zip(*batch)
+            self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1)
+
+    def test_train_prediction(self):
+        """Test that error on test data improves as model is trained."""
+        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+        slr.setInitialWeights([0.0])
+
+        # Create ten batches with 100 sample points in each.
+        batches = []
+        for i in range(10):
+            batch = LinearDataGenerator.generateLinearInput(
+                0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
+            batches.append(sc.parallelize(batch))
+
+        predict_batches = [
+            b.map(lambda lp: (lp.label, lp.features)) for b in batches]
+        mean_absolute_errors = []
+
+        def func(rdd):
+            true, predicted = zip(*rdd.collect())
+            mean_absolute_errors.append(mean(abs(array(true) - array(predicted))))
+
+        model_weights = []
+        input_stream = self.ssc.queueStream(batches)
+        output_stream = self.ssc.queueStream(predict_batches)
+        t = time()
+        slr.trainOn(input_stream)
+        output_stream = slr.predictOnValues(output_stream)
+        output_stream.foreachRDD(func)
+        self.ssc.start()
+        self._ssc_wait(t, 10, 0.01)
+        self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2)
+
+
+class MLUtilsTests(MLlibTestCase):
+    def test_append_bias(self):
+        data = [2.0, 2.0, 2.0]
+        ret = MLUtils.appendBias(data)
+        self.assertEqual(ret[3], 1.0)
+        self.assertEqual(type(ret), DenseVector)
+
+    def test_append_bias_with_vector(self):
+        data = Vectors.dense([2.0, 2.0, 2.0])
+        ret = MLUtils.appendBias(data)
+        self.assertEqual(ret[3], 1.0)
+        self.assertEqual(type(ret), DenseVector)
+
+    def test_append_bias_with_sp_vector(self):
+        data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
+        expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
+        # Returned value must be SparseVector
+        ret = MLUtils.appendBias(data)
+        self.assertEqual(ret, expected)
+        self.assertEqual(type(ret), SparseVector)
+
+    def test_load_vectors(self):
+        import shutil
+        data = [
+            [1.0, 2.0, 3.0],
+            [1.0, 2.0, 3.0]
+        ]
+        temp_dir = tempfile.mkdtemp()
+        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
+        try:
+            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
+            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
+            ret = ret_rdd.collect()
+            self.assertEqual(len(ret), 2)
+            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
+            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
+        except:
+            self.fail()
+        finally:
+            shutil.rmtree(load_vectors_path)
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print("NOTE: Skipping SciPy tests as it does not seem to be installed")
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 348238319e407..875d3b2d642c6 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None):
         minPartitions = minPartitions or min(sc.defaultParallelism, 2)
         return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
 
+    @staticmethod
+    def appendBias(data):
+        """
+        Returns a new vector with `1.0` (bias) appended to
+        the end of the input vector.
+        """
+        vec = _convert_to_vector(data)
+        if isinstance(vec, SparseVector):
+            newIndices = np.append(vec.indices, len(vec))
+            newValues = np.append(vec.values, 1.0)
+            return SparseVector(len(vec) + 1, newIndices, newValues)
+        else:
+            return _convert_to_vector(np.append(vec.toArray(), 1.0))
+
+    @staticmethod
+    def loadVectors(sc, path):
+        """
+        Loads vectors saved using `RDD[Vector].saveAsTextFile`
+        with the default number of partitions.
+        """
+        return callMLlibFunc("loadVectors", sc, path)
+
 
 class Saveable(object):
     """
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1b64be23a667e..3218bed5c74fc 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -121,10 +121,23 @@ def _parse_memory(s):
 
 
 def _load_from_socket(port, serializer):
-    sock = socket.socket()
-    sock.settimeout(3)
+    sock = None
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, canonname, sa = res
+        sock = socket.socket(af, socktype, proto)
+        try:
+            sock.settimeout(3)
+            sock.connect(sa)
+        except socket.error:
+            sock.close()
+            sock = None
+            continue
+        break
+    if not sock:
+        raise Exception("could not open socket")
     try:
-        sock.connect(("localhost", port))
         rf = sock.makefile("rb", 65536)
         for item in serializer.load_stream(rf):
             yield item
@@ -687,12 +700,14 @@ def groupBy(self, f, numPartitions=None):
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
 
     @ignore_unicode_prefix
-    def pipe(self, command, env={}):
+    def pipe(self, command, env={}, checkCode=False):
         """
         Return an RDD created by piping elements to a forked external process.
 
         >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
         [u'1', u'2', u'', u'3']
+
+        :param checkCode: whether or not to check the return value of the shell command.
         """
         def func(iterator):
             pipe = Popen(
@@ -704,7 +719,17 @@ def pipe_objs(out):
                     out.write(s.encode('utf-8'))
                 out.close()
             Thread(target=pipe_objs, args=[pipe.stdin]).start()
-            return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
+
+            # Generator chained after the process output below, so the exit
+            # code is only checked once all of the output has been consumed.
+            def check_return_code():
+                pipe.wait()
+                if checkCode and pipe.returncode:
+                    raise Exception("Pipe function `%s' exited "
+                                    "with error code %d" % (command, pipe.returncode))
+                else:
+                    for i in range(0):
+                        yield i
+            return (x.rstrip(b'\n').decode('utf-8') for x in
+                    chain(iter(pipe.stdout.readline, b''), check_return_code()))
         return self.mapPartitions(func)
 
     def foreach(self, f):
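With the new flag, a non-zero exit code surfaces as an error instead of being silently dropped; note that it is raised lazily, once the piped output has been consumed. A hedged usage sketch (`sc` is an existing SparkContext):

    # Default behaviour is unchanged: the exit code is ignored.
    sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
    # [u'1', u'2', u'', u'3']

    # 'grep 4' exits with code 1 because nothing matches, so consuming the
    # result raises the "Pipe function ... exited with error code 1" exception above.
    # sc.parallelize(['1', '2', '', '3']).pipe('grep 4', checkCode=True).collect()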
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index dc239226e6d3c..c93a15badae29 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -30,9 +30,10 @@
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql import since
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
-    _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
+    _infer_schema, _has_nulltype, _merge_type, _create_converter
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.utils import install_exception_handler
 
 try:
     import pandas
@@ -96,6 +97,7 @@ def __init__(self, sparkContext, sqlContext=None):
         self._jvm = self._sc._jvm
         self._scala_SQLContext = sqlContext
         _monkey_patch_RDD(self)
+        install_exception_handler()
 
     @property
     def _ssql_ctx(self):
@@ -203,7 +205,37 @@ def registerFunction(self, name, f, returnType=StringType()):
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
+    def _inferSchemaFromList(self, data):
+        """
+        Infer schema from list of Row or tuple.
+
+        :param data: list of Row or tuple
+        :return: StructType
+        """
+        if not data:
+            raise ValueError("can not infer schema from empty dataset")
+        first = data[0]
+        if type(first) is dict:
+            warnings.warn("inferring schema from dict is deprecated,"
+                          "please use pyspark.sql.Row instead")
+        schema = _infer_schema(first)
+        if _has_nulltype(schema):
+            for r in data:
+                schema = _merge_type(schema, _infer_schema(r))
+                if not _has_nulltype(schema):
+                    break
+            else:
+                raise ValueError("Some of types cannot be determined after inferring")
+        return schema
+
     def _inferSchema(self, rdd, samplingRatio=None):
+        """
+        Infer schema from an RDD of Row or tuple.
+
+        :param rdd: an RDD of Row or tuple
+        :param samplingRatio: sampling ratio, or no sampling (default)
+        :return: StructType
+        """
         first = rdd.first()
         if not first:
             raise ValueError("The first row in RDD is empty, "
@@ -312,16 +344,20 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
 
         >>> sqlContext.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
         [Row(name=u'Alice', age=1)]
+        >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
+        [Row(0=1, 1=2)]
         """
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
         if has_pandas and isinstance(data, pandas.DataFrame):
             if schema is None:
-                schema = list(data.columns)
+                schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
         if not isinstance(data, RDD):
+            if not isinstance(data, list):
+                data = list(data)
             try:
                 # data could be list, tuple, generator ...
                 rdd = self._sc.parallelize(data)
@@ -330,32 +366,29 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         else:
             rdd = data
 
-        if schema is None:
-            schema = self._inferSchema(rdd, samplingRatio)
+        if schema is None or isinstance(schema, (list, tuple)):
+            if isinstance(data, RDD):
+                struct = self._inferSchema(rdd, samplingRatio)
+            else:
+                struct = self._inferSchemaFromList(data)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+            schema = struct
             converter = _create_converter(schema)
             rdd = rdd.map(converter)
 
-        if isinstance(schema, (list, tuple)):
-            first = rdd.first()
-            if not isinstance(first, (list, tuple)):
-                raise TypeError("each row in `rdd` should be list or tuple, "
-                                "but got %r" % type(first))
-            row_cls = Row(*schema)
-            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-
-        # take the first few rows to verify schema
-        rows = rdd.take(10)
-        # Row() cannot been deserialized by Pyrolite
-        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
-            rdd = rdd.map(tuple)
+        elif isinstance(schema, StructType):
+            # take the first few rows to verify schema
             rows = rdd.take(10)
+            for row in rows:
+                _verify_type(row, schema)
 
-        for row in rows:
-            _verify_type(row, schema)
+        else:
+            raise TypeError("schema should be StructType or list or None")
 
         # convert python objects to sql data
-        converter = _python_to_sql_converter(schema)
-        rdd = rdd.map(converter)
+        rdd = rdd.map(schema.toInternal)
 
         jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
         df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
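After this rewrite, a list or tuple schema simply renames the fields of the inferred StructType, while a full StructType is verified against a sample of rows. A small sketch (`sqlContext` assumed to exist):

    # Column names are applied, in order, to the fields of the inferred struct.
    df = sqlContext.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
    df.dtypes  # expected: [('id', 'bigint'), ('value', 'string')]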
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 152b87351db31..83e02b85f06f1 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -31,7 +31,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql import since
-from pyspark.sql.types import _create_cls, _parse_datatype_json_string
+from pyspark.sql.types import _parse_datatype_json_string
 from pyspark.sql.column import Column, _to_seq, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter
 from pyspark.sql.types import *
@@ -83,15 +83,7 @@ def rdd(self):
         """
         if self._lazy_rdd is None:
             jrdd = self._jdf.javaToPython()
-            rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
-            schema = self.schema
-
-            def applySchema(it):
-                cls = _create_cls(schema)
-                return map(cls, it)
-
-            self._lazy_rdd = rdd.mapPartitions(applySchema)
-
+            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
         return self._lazy_rdd
 
     @property
@@ -247,9 +239,12 @@ def isLocal(self):
         return self._jdf.isLocal()
 
     @since(1.3)
-    def show(self, n=20):
+    def show(self, n=20, truncate=True):
         """Prints the first ``n`` rows to the console.
 
+        :param n: Number of rows to show.
+        :param truncate: Whether to truncate long strings and align cells right.
+
         >>> df
         DataFrame[age: int, name: string]
         >>> df.show()
@@ -260,7 +255,7 @@ def show(self, n=20):
         |  5|  Bob|
         +---+-----+
         """
-        print(self._jdf.showString(n))
+        print(self._jdf.showString(n, truncate))
 
     def __repr__(self):
         return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
@@ -284,9 +279,7 @@ def collect(self):
         """
         with SCCallSiteSync(self._sc) as css:
             port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
-        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
-        cls = _create_cls(self.schema)
-        return [cls(r) for r in rs]
+        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
 
     @ignore_unicode_prefix
     @since(1.3)
@@ -481,13 +474,12 @@ def dtypes(self):
         return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
 
     @property
-    @ignore_unicode_prefix
     @since(1.3)
     def columns(self):
         """Returns all column names as a list.
 
         >>> df.columns
-        [u'age', u'name']
+        ['age', 'name']
         """
         return [f.name for f in self.schema.fields]
 
@@ -800,11 +792,11 @@ def groupBy(self, *cols):
             Each element should be a column name (string) or an expression (:class:`Column`).
 
         >>> df.groupBy().avg().collect()
-        [Row(AVG(age)=3.5)]
+        [Row(avg(age)=3.5)]
         >>> df.groupBy('name').agg({'age': 'mean'}).collect()
-        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
         >>> df.groupBy(df.name).avg().collect()
-        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
         >>> df.groupBy(['name', df.age]).count().collect()
         [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
         """
@@ -862,10 +854,10 @@ def agg(self, *exprs):
         (shorthand for ``df.groupBy.agg()``).
 
         >>> df.agg({"age": "max"}).collect()
-        [Row(MAX(age)=5)]
+        [Row(max(age)=5)]
         >>> from pyspark.sql import functions as F
         >>> df.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2)]
+        [Row(min(age)=2)]
         """
         return self.groupBy().agg(*exprs)
 
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7d3d0361610b7..dca39fa833435 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -39,11 +39,15 @@
     'coalesce',
     'countDistinct',
     'explode',
+    'log2',
+    'md5',
     'monotonicallyIncreasingId',
     'rand',
     'randn',
+    'sha1',
     'sha2',
     'sparkPartitionId',
+    'strlen',
     'struct',
     'udf',
     'when']
@@ -262,7 +266,7 @@ def coalesce(*cols):
 
     >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
     +-------------+
-    |Coalesce(a,b)|
+    |coalesce(a,b)|
     +-------------+
     |         null|
     |            1|
@@ -271,7 +275,7 @@ def coalesce(*cols):
 
     >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
     +----+----+---------------+
-    |   a|   b|Coalesce(a,0.0)|
+    |   a|   b|coalesce(a,0.0)|
     +----+----+---------------+
     |null|null|            0.0|
     |   1|null|            1.0|
@@ -319,6 +323,33 @@ def explode(col):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def levenshtein(left, right):
+    """Computes the Levenshtein distance of the two given strings.
+
+    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
+    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
+    [Row(d=3)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
+    return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def md5(col):
+    """Calculates the MD5 digest and returns the value as a 32 character hex string.
+
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
+    [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.md5(_to_java_column(col))
+    return Column(jc)
+
+
 @since(1.4)
 def monotonicallyIncreasingId():
     """A column that generates monotonically increasing 64-bit integers.
@@ -364,6 +395,47 @@ def randn(seed=None):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def hex(col):
+    """Computes hex value of the given column, which could be StringType,
+    BinaryType, IntegerType or LongType.
+
+    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
+    [Row(hex(a)=u'414243', hex(b)=u'3')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.hex(_to_java_column(col))
+    return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def unhex(col):
+    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
+    and converts to the byte representation of the number.
+
+    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
+    [Row(unhex(a)=bytearray(b'ABC'))]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.unhex(_to_java_column(col))
+    return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def sha1(col):
+    """Returns the hex string result of SHA-1.
+
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+    [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.sha1(_to_java_column(col))
+    return Column(jc)
+
+
 @ignore_unicode_prefix
 @since(1.5)
 def sha2(col, numBits):
@@ -382,6 +454,43 @@ def sha2(col, numBits):
     return Column(jc)
 
 
+@since(1.5)
+def shiftLeft(col, numBits):
+    """Shift the the given value numBits left.
+
+    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
+    [Row(r=42)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)
+    return Column(jc)
+
+
+@since(1.5)
+def shiftRight(col, numBits):
+    """Shift the the given value numBits right.
+
+    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
+    [Row(r=21)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
+    return Column(jc)
+
+
+@since(1.5)
+def shiftRightUnsigned(col, numBits):
+    """Unsigned shift the the given value numBits right.
+
+    >>> df = sqlContext.createDataFrame([(-42,)], ['a'])
+    >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
+    [Row(r=9223372036854775787)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
+    return Column(jc)
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
@@ -395,13 +504,24 @@ def sparkPartitionId():
     return Column(sc._jvm.functions.sparkPartitionId())
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def strlen(col):
+    """Calculates the length of a string expression.
+
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
+    [Row(length=3)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.strlen(_to_java_column(col)))
+
+
 @ignore_unicode_prefix
 @since(1.4)
 def struct(*cols):
     """Creates a new struct column.
 
     :param cols: list of column names (string) or list of :class:`Column` expressions
-        that are named or aliased.
 
     >>> df.select(struct('age', 'name').alias("struct")).collect()
     [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
@@ -457,6 +577,17 @@ def log(arg1, arg2=None):
     return Column(jc)
 
 
+@since(1.5)
+def log2(col):
+    """Returns the base-2 logarithm of the argument.
+
+    >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect()
+    [Row(log2=2.0)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.log2(_to_java_column(col)))
+
+
 @since(1.4)
 def lag(col, count=1, default=None):
     """
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 5a37a673ee80c..04594d5a836ce 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -75,11 +75,11 @@ def agg(self, *exprs):
 
         >>> gdf = df.groupBy(df.name)
         >>> gdf.agg({"*": "count"}).collect()
-        [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
+        [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]
 
         >>> from pyspark.sql import functions as F
         >>> gdf.agg(F.min(df.age)).collect()
-        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
+        [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
@@ -110,9 +110,9 @@ def mean(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().mean('age').collect()
-        [Row(AVG(age)=3.5)]
+        [Row(avg(age)=3.5)]
         >>> df3.groupBy().mean('age', 'height').collect()
-        [Row(AVG(age)=3.5, AVG(height)=82.5)]
+        [Row(avg(age)=3.5, avg(height)=82.5)]
         """
 
     @df_varargs_api
@@ -125,9 +125,9 @@ def avg(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().avg('age').collect()
-        [Row(AVG(age)=3.5)]
+        [Row(avg(age)=3.5)]
         >>> df3.groupBy().avg('age', 'height').collect()
-        [Row(AVG(age)=3.5, AVG(height)=82.5)]
+        [Row(avg(age)=3.5, avg(height)=82.5)]
         """
 
     @df_varargs_api
@@ -136,9 +136,9 @@ def max(self, *cols):
         """Computes the max value for each numeric columns for each group.
 
         >>> df.groupBy().max('age').collect()
-        [Row(MAX(age)=5)]
+        [Row(max(age)=5)]
         >>> df3.groupBy().max('age', 'height').collect()
-        [Row(MAX(age)=5, MAX(height)=85)]
+        [Row(max(age)=5, max(height)=85)]
         """
 
     @df_varargs_api
@@ -149,9 +149,9 @@ def min(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().min('age').collect()
-        [Row(MIN(age)=2)]
+        [Row(min(age)=2)]
         >>> df3.groupBy().min('age', 'height').collect()
-        [Row(MIN(age)=2, MIN(height)=80)]
+        [Row(min(age)=2, min(height)=80)]
         """
 
     @df_varargs_api
@@ -162,9 +162,9 @@ def sum(self, *cols):
         :param cols: list of column names (string). Non-numeric columns are ignored.
 
         >>> df.groupBy().sum('age').collect()
-        [Row(SUM(age)=7)]
+        [Row(sum(age)=7)]
         >>> df3.groupBy().sum('age', 'height').collect()
-        [Row(SUM(age)=7, SUM(height)=165)]
+        [Row(sum(age)=7, sum(height)=165)]
         """
 
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ffee43a94baba..241eac45cfe36 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements.  See the NOTICE file distributed with
@@ -46,6 +47,7 @@
 from pyspark.tests import ReusedPySparkTestCase
 from pyspark.sql.functions import UserDefinedFunction
 from pyspark.sql.window import Window
+from pyspark.sql.utils import AnalysisException
 
 
 class UTC(datetime.tzinfo):
@@ -149,6 +151,17 @@ def test_range(self):
         self.assertEqual(self.sqlCtx.range(-2).count(), 0)
         self.assertEqual(self.sqlCtx.range(3).count(), 3)
 
+    def test_duplicated_column_names(self):
+        df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
+        row = df.select('*').first()
+        self.assertEqual(1, row[0])
+        self.assertEqual(2, row[1])
+        self.assertEqual("Row(c=1, c=2)", str(row))
+        # Cannot access columns
+        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
+        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
+        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
+
     def test_explode(self):
         from pyspark.sql.functions import explode
         d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
@@ -399,6 +412,14 @@ def test_apply_schema_with_udt(self):
         point = df.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
+    def test_udf_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.sc.parallelize([row]).toDF()
+        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+
     def test_parquet_with_udt(self):
         from pyspark.sql.tests import ExamplePoint
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
@@ -516,6 +537,35 @@ def test_between_function(self):
         self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
                          df.filter(df.a.between(df.b, df.c)).collect())
 
+    def test_struct_type(self):
+        from pyspark.sql.types import StructType, StringType, StructField
+        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+        struct2 = StructType([StructField("f1", StringType(), True),
+                              StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1, struct2)
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+        struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1, struct2)
+
+        struct1 = (StructType().add(StructField("f1", StringType(), True))
+                   .add(StructField("f2", StringType(), True, None)))
+        struct2 = StructType([StructField("f1", StringType(), True),
+                              StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1, struct2)
+
+        struct1 = (StructType().add(StructField("f1", StringType(), True))
+                   .add(StructField("f2", StringType(), True, None)))
+        struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1, struct2)
+
+        # Catch exception raised during improper construction
+        try:
+            struct1 = StructType().add("name")
+            self.assertEqual(1, 0)
+        except ValueError:
+            self.assertEqual(1, 1)
+
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
@@ -598,6 +648,14 @@ def test_access_column(self):
         self.assertRaises(IndexError, lambda: df["bad_key"])
         self.assertRaises(TypeError, lambda: df[{}])
 
+    def test_column_name_with_non_ascii(self):
+        df = self.sqlCtx.createDataFrame([(1,)], ["数量"])
+        self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
+        self.assertEqual("DataFrame[数量: bigint]", str(df))
+        self.assertEqual([("数量", 'bigint')], df.dtypes)
+        self.assertEqual(1, df.select("数量").first()[0])
+        self.assertEqual(1, df.select(df["数量"]).first()[0])
+
     def test_access_nested_types(self):
         df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
         self.assertEqual(1, df.select(df.l[0]).first()[0])
@@ -647,19 +705,31 @@ def test_filter_with_datetime(self):
     def test_time_with_timezone(self):
         day = datetime.date.today()
         now = datetime.datetime.now()
-        ts = time.mktime(now.timetuple()) + now.microsecond / 1e6
+        ts = time.mktime(now.timetuple())
         # class in __main__ is not serializable
         from pyspark.sql.tests import UTC
         utc = UTC()
-        utcnow = datetime.datetime.fromtimestamp(ts, utc)
+        utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
+        # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
+        utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
         df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
         day1, now1, utcnow1 = df.first()
-        # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version
-        self.assertEqual(day1.date(), day)
-        # Pyrolite does not support microsecond, the error should be
-        # less than 1 millisecond
-        self.assertTrue(now - now1 < datetime.timedelta(0.001))
-        self.assertTrue(now - utcnow1 < datetime.timedelta(0.001))
+        self.assertEqual(day1, day)
+        self.assertEqual(now, now1)
+        self.assertEqual(now, utcnow1)
+
+    def test_decimal(self):
+        from decimal import Decimal
+        schema = StructType([StructField("decimal", DecimalType(10, 5))])
+        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
+        row = df.select(df.decimal + 1).first()
+        self.assertEqual(row[0], Decimal("4.14159"))
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.write.parquet(tmpPath)
+        df2 = self.sqlCtx.read.parquet(tmpPath)
+        row = df2.first()
+        self.assertEqual(row[0], Decimal("3.14159"))
 
     def test_dropna(self):
         schema = StructType([
@@ -818,6 +888,12 @@ def test_replace(self):
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)
 
+    def test_capture_analysis_exception(self):
+        self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
+        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
+        # RuntimeException should not be captured
+        self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc"))
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
 
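The duplicated-column test above translates to the following behaviour; a hedged sketch, assuming an existing SQLContext named `sqlContext` and that the AnalysisException translation added in pyspark/sql/utils.py (later in this patch) is installed:

    from pyspark.sql.utils import AnalysisException

    df = sqlContext.createDataFrame([(1, 2)], ["c", "c"])
    row = df.select("*").first()
    print(row[0], row[1])            # 1 2 -- positional access still works
    try:
        df.select(df["c"]).first()   # name-based access is ambiguous
    except AnalysisException as e:
        print("cannot resolve column by name:", e)
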
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 23d9adb0daea1..f75791fad1612 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -20,13 +20,9 @@
 import time
 import datetime
 import calendar
-import keyword
-import warnings
 import json
 import re
-import weakref
 from array import array
-from operator import itemgetter
 
 if sys.version >= "3":
     long = int
@@ -71,6 +67,26 @@ def json(self):
                           separators=(',', ':'),
                           sort_keys=True)
 
+    def needConversion(self):
+        """
+        Does this type need conversion between a Python object and an internal SQL object?
+
+        This is used to avoid unnecessary conversion for ArrayType/MapType/StructType.
+        """
+        return False
+
+    def toInternal(self, obj):
+        """
+        Converts a Python object into an internal SQL object.
+        """
+        return obj
+
+    def fromInternal(self, obj):
+        """
+        Converts an internal SQL object into a native Python object.
+        """
+        return obj
+
 
 # This singleton pattern does not work with pickle, you will get
 # another object after pickle and unpickle
@@ -143,6 +159,17 @@ class DateType(AtomicType):
 
     __metaclass__ = DataTypeSingleton
 
+    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+
+    def needConversion(self):
+        return True
+
+    def toInternal(self, d):
+        return d and d.toordinal() - self.EPOCH_ORDINAL
+
+    def fromInternal(self, v):
+        return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
+
 
 class TimestampType(AtomicType):
     """Timestamp (datetime.datetime) data type.
@@ -150,6 +177,20 @@ class TimestampType(AtomicType):
 
     __metaclass__ = DataTypeSingleton
 
+    def needConversion(self):
+        return True
+
+    def toInternal(self, dt):
+        if dt is not None:
+            seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
+                       else time.mktime(dt.timetuple()))
+            return int(seconds * 1e6 + dt.microsecond)
+
+    def fromInternal(self, ts):
+        if ts is not None:
+            # using int to avoid precision loss in float
+            return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)
+
 
 class DecimalType(FractionalType):
     """Decimal (decimal.Decimal) data type.
@@ -259,6 +300,19 @@ def fromJson(cls, json):
         return ArrayType(_parse_datatype_json_value(json["elementType"]),
                          json["containsNull"])
 
+    def needConversion(self):
+        return self.elementType.needConversion()
+
+    def toInternal(self, obj):
+        if not self.needConversion():
+            return obj
+        return obj and [self.elementType.toInternal(v) for v in obj]
+
+    def fromInternal(self, obj):
+        if not self.needConversion():
+            return obj
+        return obj and [self.elementType.fromInternal(v) for v in obj]
+
 
 class MapType(DataType):
     """Map data type.
@@ -304,6 +358,21 @@ def fromJson(cls, json):
                        _parse_datatype_json_value(json["valueType"]),
                        json["valueContainsNull"])
 
+    def needConversion(self):
+        return self.keyType.needConversion() or self.valueType.needConversion()
+
+    def toInternal(self, obj):
+        if not self.needConversion():
+            return obj
+        return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v))
+                            for k, v in obj.items())
+
+    def fromInternal(self, obj):
+        if not self.needConversion():
+            return obj
+        return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v))
+                            for k, v in obj.items())
+
 
 class StructField(DataType):
     """A field in :class:`StructType`.
@@ -311,7 +380,7 @@ class StructField(DataType):
     :param name: string, name of the field.
     :param dataType: :class:`DataType` of the field.
     :param nullable: boolean, whether the field can be null (None) or not.
-    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
+    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
     """
 
     def __init__(self, name, dataType, nullable=True, metadata=None):
@@ -324,6 +393,8 @@ def __init__(self, name, dataType, nullable=True, metadata=None):
         False
         """
         assert isinstance(dataType, DataType), "dataType should be DataType"
+        if not isinstance(name, str):
+            name = name.encode('utf-8')
         self.name = name
         self.dataType = dataType
         self.nullable = nullable
@@ -349,14 +420,22 @@ def fromJson(cls, json):
                            json["nullable"],
                            json["metadata"])
 
+    def needConversion(self):
+        return self.dataType.needConversion()
+
+    def toInternal(self, obj):
+        return self.dataType.toInternal(obj)
+
+    def fromInternal(self, obj):
+        return self.dataType.fromInternal(obj)
+
 
 class StructType(DataType):
     """Struct type, consisting of a list of :class:`StructField`.
 
     This is the data type representing a :class:`Row`.
     """
-
-    def __init__(self, fields):
+    def __init__(self, fields=None):
         """
         >>> struct1 = StructType([StructField("f1", StringType(), True)])
         >>> struct2 = StructType([StructField("f1", StringType(), True)])
@@ -368,8 +447,58 @@ def __init__(self, fields):
         >>> struct1 == struct2
         False
         """
-        assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
-        self.fields = fields
+        if not fields:
+            self.fields = []
+            self.names = []
+        else:
+            self.fields = fields
+            self.names = [f.name for f in fields]
+            assert all(isinstance(f, StructField) for f in fields),\
+                "fields should be a list of StructField"
+        self._needSerializeFields = None
+
+    def add(self, field, data_type=None, nullable=True, metadata=None):
+        """
+        Construct a StructType by adding new elements to it to define the schema. The method accepts
+        either:
+            a) A single parameter which is a StructField object.
+            b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
+             metadata (optional)). The data_type parameter may be either a String or a DataType object.
+
+        >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+        >>> struct2 = StructType([StructField("f1", StringType(), True),\
+         StructField("f2", StringType(), True, None)])
+        >>> struct1 == struct2
+        True
+        >>> struct1 = StructType().add(StructField("f1", StringType(), True))
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
+        >>> struct1 == struct2
+        True
+        >>> struct1 = StructType().add("f1", "string", True)
+        >>> struct2 = StructType([StructField("f1", StringType(), True)])
+        >>> struct1 == struct2
+        True
+
+        :param field: Either the name of the field or a StructField object
+        :param data_type: If present, the DataType of the StructField to create
+        :param nullable: Whether the field to add should be nullable (default True)
+        :param metadata: Any additional metadata (default None)
+        :return: a new updated StructType
+        """
+        if isinstance(field, StructField):
+            self.fields.append(field)
+            self.names.append(field.name)
+        else:
+            if isinstance(field, str) and data_type is None:
+                raise ValueError("Must specify DataType if passing name of struct_field to create.")
+
+            if isinstance(data_type, str):
+                data_type_f = _parse_datatype_json_value(data_type)
+            else:
+                data_type_f = data_type
+            self.fields.append(StructField(field, data_type_f, nullable, metadata))
+            self.names.append(field)
+        return self
 
     def simpleString(self):
         return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
@@ -386,6 +515,41 @@ def jsonValue(self):
     def fromJson(cls, json):
         return StructType([StructField.fromJson(f) for f in json["fields"]])
 
+    def needConversion(self):
+        # We need to convert Row()/namedtuple into tuple()
+        return True
+
+    def toInternal(self, obj):
+        if obj is None:
+            return
+
+        if self._needSerializeFields is None:
+            self._needSerializeFields = any(f.needConversion() for f in self.fields)
+
+        if self._needSerializeFields:
+            if isinstance(obj, dict):
+                return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
+            elif isinstance(obj, (tuple, list)):
+                return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
+            else:
+                raise ValueError("Unexpected tuple %r with StructType" % obj)
+        else:
+            if isinstance(obj, dict):
+                return tuple(obj.get(n) for n in self.names)
+            elif isinstance(obj, (list, tuple)):
+                return tuple(obj)
+            else:
+                raise ValueError("Unexpected tuple %r with StructType" % obj)
+
+    def fromInternal(self, obj):
+        if obj is None:
+            return
+        if isinstance(obj, Row):
+            # it's already converted by pickler
+            return obj
+        values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
+        return _create_row(self.names, values)
+
 
 class UserDefinedType(DataType):
     """User-defined type (UDT).
@@ -418,17 +582,35 @@ def scalaUDT(cls):
         """
         raise NotImplementedError("UDT must have a paired Scala UDT.")
 
+    def needConversion(self):
+        return True
+
+    @classmethod
+    def _cachedSqlType(cls):
+        """
+        Cache the sqlType() on the class, because it's heavily used in `toInternal`.
+        """
+        if not hasattr(cls, "_cached_sql_type"):
+            cls._cached_sql_type = cls.sqlType()
+        return cls._cached_sql_type
+
+    def toInternal(self, obj):
+        return self._cachedSqlType().toInternal(self.serialize(obj))
+
+    def fromInternal(self, obj):
+        return self.deserialize(self._cachedSqlType().fromInternal(obj))
+
     def serialize(self, obj):
         """
         Converts a user-type object into a SQL datum.
         """
-        raise NotImplementedError("UDT must implement serialize().")
+        raise NotImplementedError("UDT must implement toInternal().")
 
     def deserialize(self, datum):
         """
         Converts a SQL datum into a user-type object.
         """
-        raise NotImplementedError("UDT must implement deserialize().")
+        raise NotImplementedError("UDT must implement fromInternal().")
 
     def simpleString(self):
         return 'udt'
@@ -625,112 +807,6 @@ def _infer_schema(row):
     return StructType(fields)
 
 
-def _need_python_to_sql_conversion(dataType):
-    """
-    Checks whether we need python to sql conversion for the given type.
-    For now, only UDTs need this conversion.
-
-    >>> _need_python_to_sql_conversion(DoubleType())
-    False
-    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
-    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
-    >>> _need_python_to_sql_conversion(schema0)
-    False
-    >>> _need_python_to_sql_conversion(ExamplePointUDT())
-    True
-    >>> schema1 = ArrayType(ExamplePointUDT(), False)
-    >>> _need_python_to_sql_conversion(schema1)
-    True
-    >>> schema2 = StructType([StructField("label", DoubleType(), False),
-    ...                       StructField("point", ExamplePointUDT(), False)])
-    >>> _need_python_to_sql_conversion(schema2)
-    True
-    """
-    if isinstance(dataType, StructType):
-        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
-    elif isinstance(dataType, ArrayType):
-        return _need_python_to_sql_conversion(dataType.elementType)
-    elif isinstance(dataType, MapType):
-        return _need_python_to_sql_conversion(dataType.keyType) or \
-            _need_python_to_sql_conversion(dataType.valueType)
-    elif isinstance(dataType, UserDefinedType):
-        return True
-    elif isinstance(dataType, (DateType, TimestampType)):
-        return True
-    else:
-        return False
-
-
-EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
-
-
-def _python_to_sql_converter(dataType):
-    """
-    Returns a converter that converts a Python object into a SQL datum for the given type.
-
-    >>> conv = _python_to_sql_converter(DoubleType())
-    >>> conv(1.0)
-    1.0
-    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
-    >>> conv([1.0, 2.0])
-    [1.0, 2.0]
-    >>> conv = _python_to_sql_converter(ExamplePointUDT())
-    >>> conv(ExamplePoint(1.0, 2.0))
-    [1.0, 2.0]
-    >>> schema = StructType([StructField("label", DoubleType(), False),
-    ...                      StructField("point", ExamplePointUDT(), False)])
-    >>> conv = _python_to_sql_converter(schema)
-    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
-    (1.0, [1.0, 2.0])
-    """
-    if not _need_python_to_sql_conversion(dataType):
-        return lambda x: x
-
-    if isinstance(dataType, StructType):
-        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
-        converters = [_python_to_sql_converter(t) for t in types]
-
-        def converter(obj):
-            if isinstance(obj, dict):
-                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
-            elif isinstance(obj, tuple):
-                if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
-                    return tuple(c(v) for c, v in zip(converters, obj))
-                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
-                    d = dict(obj)
-                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
-                else:
-                    return tuple(c(v) for c, v in zip(converters, obj))
-            elif obj is not None:
-                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
-        return converter
-    elif isinstance(dataType, ArrayType):
-        element_converter = _python_to_sql_converter(dataType.elementType)
-        return lambda a: a and [element_converter(v) for v in a]
-    elif isinstance(dataType, MapType):
-        key_converter = _python_to_sql_converter(dataType.keyType)
-        value_converter = _python_to_sql_converter(dataType.valueType)
-        return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
-
-    elif isinstance(dataType, UserDefinedType):
-        return lambda obj: obj and dataType.serialize(obj)
-
-    elif isinstance(dataType, DateType):
-        return lambda d: d and d.toordinal() - EPOCH_ORDINAL
-
-    elif isinstance(dataType, TimestampType):
-
-        def to_posix_timstamp(dt):
-            if dt:
-                seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
-                           else time.mktime(dt.timetuple()))
-                return int(seconds * 1e7 + dt.microsecond * 10)
-        return to_posix_timstamp
-
-    else:
-        raise ValueError("Unexpected type %r" % dataType)
-
-
 def _has_nulltype(dt):
     """ Return whether there is NullType in `dt` or not """
     if isinstance(dt, StructType):
@@ -1018,19 +1094,26 @@ def _verify_type(obj, dataType):
     if obj is None:
         return
 
+    # StringType can work with any types
+    if isinstance(dataType, StringType):
+        return
+
     if isinstance(dataType, UserDefinedType):
         if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
             raise ValueError("%r is not an instance of type %r" % (obj, dataType))
-        _verify_type(dataType.serialize(obj), dataType.sqlType())
+        _verify_type(dataType.toInternal(obj), dataType.sqlType())
         return
 
     _type = type(dataType)
     assert _type in _acceptable_types, "unknown datatype: %s" % dataType
 
-    # subclass of them can not be deserialized in JVM
-    if type(obj) not in _acceptable_types[_type]:
-        raise TypeError("%s can not accept object in type %s"
-                        % (dataType, type(obj)))
+    if _type is StructType:
+        if not isinstance(obj, (tuple, list)):
+            raise TypeError("StructType can not accept object in type %s" % type(obj))
+    else:
+        # subclasses of these types can not be deserialized in the JVM
+        if type(obj) not in _acceptable_types[_type]:
+            raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
 
     if isinstance(dataType, ArrayType):
         for i in obj:
@@ -1048,159 +1131,10 @@ def _verify_type(obj, dataType):
         for v, f in zip(obj, dataType.fields):
             _verify_type(v, f.dataType)
 
-_cached_cls = weakref.WeakValueDictionary()
-
-
-def _restore_object(dataType, obj):
-    """ Restore object during unpickling. """
-    # use id(dataType) as key to speed up lookup in dict
-    # Because of batched pickling, dataType will be the
-    # same object in most cases.
-    k = id(dataType)
-    cls = _cached_cls.get(k)
-    if cls is None or cls.__datatype is not dataType:
-        # use dataType as key to avoid create multiple class
-        cls = _cached_cls.get(dataType)
-        if cls is None:
-            cls = _create_cls(dataType)
-            _cached_cls[dataType] = cls
-        cls.__datatype = dataType
-        _cached_cls[k] = cls
-    return cls(obj)
-
-
-def _create_object(cls, v):
-    """ Create an customized object with class `cls`. """
-    # datetime.date would be deserialized as datetime.datetime
-    # from java type, so we need to set it back.
-    if cls is datetime.date and isinstance(v, datetime.datetime):
-        return v.date()
-    return cls(v) if v is not None else v
-
-
-def _create_getter(dt, i):
-    """ Create a getter for item `i` with schema """
-    cls = _create_cls(dt)
-
-    def getter(self):
-        return _create_object(cls, self[i])
-
-    return getter
-
-
-def _has_struct_or_date(dt):
-    """Return whether `dt` is or has StructType/DateType in it"""
-    if isinstance(dt, StructType):
-        return True
-    elif isinstance(dt, ArrayType):
-        return _has_struct_or_date(dt.elementType)
-    elif isinstance(dt, MapType):
-        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
-    elif isinstance(dt, DateType):
-        return True
-    elif isinstance(dt, UserDefinedType):
-        return True
-    return False
-
-
-def _create_properties(fields):
-    """Create properties according to fields"""
-    ps = {}
-    for i, f in enumerate(fields):
-        name = f.name
-        if (name.startswith("__") and name.endswith("__")
-                or keyword.iskeyword(name)):
-            warnings.warn("field name %s can not be accessed in Python,"
-                          "use position to access it instead" % name)
-        if _has_struct_or_date(f.dataType):
-            # delay creating object until accessing it
-            getter = _create_getter(f.dataType, i)
-        else:
-            getter = itemgetter(i)
-        ps[name] = property(getter)
-    return ps
-
-
-def _create_cls(dataType):
-    """
-    Create an class by dataType
-
-    The created class is similar to namedtuple, but can have nested schema.
-
-    >>> schema = _parse_schema_abstract("a b c")
-    >>> row = (1, 1.0, "str")
-    >>> schema = _infer_schema_type(row, schema)
-    >>> obj = _create_cls(schema)(row)
-    >>> import pickle
-    >>> pickle.loads(pickle.dumps(obj))
-    Row(a=1, b=1.0, c='str')
-
-    >>> row = [[1], {"key": (1, 2.0)}]
-    >>> schema = _parse_schema_abstract("a[] b{c d}")
-    >>> schema = _infer_schema_type(row, schema)
-    >>> obj = _create_cls(schema)(row)
-    >>> pickle.loads(pickle.dumps(obj))
-    Row(a=[1], b={'key': Row(c=1, d=2.0)})
-    >>> pickle.loads(pickle.dumps(obj.a))
-    [1]
-    >>> pickle.loads(pickle.dumps(obj.b))
-    {'key': Row(c=1, d=2.0)}
-    """
-
-    if isinstance(dataType, ArrayType):
-        cls = _create_cls(dataType.elementType)
-
-        def List(l):
-            if l is None:
-                return
-            return [_create_object(cls, v) for v in l]
-
-        return List
-
-    elif isinstance(dataType, MapType):
-        kcls = _create_cls(dataType.keyType)
-        vcls = _create_cls(dataType.valueType)
-
-        def Dict(d):
-            if d is None:
-                return
-            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
-
-        return Dict
-
-    elif isinstance(dataType, DateType):
-        return datetime.date
-
-    elif isinstance(dataType, UserDefinedType):
-        return lambda datum: dataType.deserialize(datum)
-
-    elif not isinstance(dataType, StructType):
-        # no wrapper for atomic types
-        return lambda x: x
-
-    class Row(tuple):
-
-        """ Row in DataFrame """
-        __datatype = dataType
-        __fields__ = tuple(f.name for f in dataType.fields)
-        __slots__ = ()
-
-        # create property for fast access
-        locals().update(_create_properties(dataType.fields))
-
-        def asDict(self):
-            """ Return as a dict """
-            return dict((n, getattr(self, n)) for n in self.__fields__)
-
-        def __repr__(self):
-            # call collect __repr__ for nested objects
-            return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
-                                          for n in self.__fields__))
-
-        def __reduce__(self):
-            return (_restore_object, (self.__datatype, tuple(self)))
 
-    return Row
+# This is used to unpickle a Row from JVM
+def _create_row_inbound_converter(dataType):
+    return lambda *a: dataType.fromInternal(a)
 
 
 def _create_row(fields, values):
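
The new toInternal/fromInternal hooks above can be exercised without a SparkContext; a small round-trip sketch using only the classes defined in this file (the numeric value is simply the day count since 1970-01-01):

    import datetime
    from pyspark.sql.types import DateType, StringType, StructType

    dt = DateType()
    d = datetime.date(2015, 7, 1)
    internal = dt.toInternal(d)                        # 16617 days since the epoch
    print(dt.fromInternal(internal) == d)              # True

    schema = StructType().add("name", StringType()).add("day", DateType())
    tup = schema.toInternal(("alice", d))              # plain tuple with converted fields
    print(schema.fromInternal(tup))                    # Row(name='alice', day=datetime.date(2015, 7, 1))
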
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
new file mode 100644
index 0000000000000..cc5b2c088b7cc
--- /dev/null
+++ b/python/pyspark/sql/utils.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import py4j
+
+
+class AnalysisException(Exception):
+    """
+    Failed to analyze a SQL query plan.
+    """
+
+
+def capture_sql_exception(f):
+    def deco(*a, **kw):
+        try:
+            return f(*a, **kw)
+        except py4j.protocol.Py4JJavaError as e:
+            s = e.java_exception.toString()
+            if s.startswith('org.apache.spark.sql.AnalysisException: '):
+                raise AnalysisException(s.split(': ', 1)[1])
+            raise
+    return deco
+
+
+def install_exception_handler():
+    """
+    Hook an exception handler into Py4J, which captures some SQL exceptions raised on the Java side.
+
+    When calling the Java API, Py4J calls `get_return_value` to parse the returned object.
+    If an exception happened in the JVM, the result is a Java exception object, and Py4J
+    raises py4j.protocol.Py4JJavaError. We replace the original `get_return_value` with one
+    that captures the Java exception and raises a Python one (with the same error message).
+
+    It's idempotent, and can safely be called multiple times.
+    """
+    original = py4j.protocol.get_return_value
+    # The original `get_return_value` in py4j.protocol is never replaced, so this stays idempotent.
+    patched = capture_sql_exception(original)
+    # only patch the one used in py4j.java_gateway (for calling the Java API)
+    py4j.java_gateway.get_return_value = patched
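
A short usage sketch for the new handler; `sqlContext` is an assumed, already-created SQLContext, and the call to install_exception_handler() is safe to repeat because it is idempotent:

    from pyspark.sql.utils import AnalysisException, install_exception_handler

    install_exception_handler()          # safe to call more than once
    try:
        sqlContext.sql("select abc")     # refers to a column that does not exist
    except AnalysisException as e:
        print("query could not be analyzed:", e)
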
diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py
new file mode 100644
index 0000000000000..cbb573f226bbe
--- /dev/null
+++ b/python/pyspark/streaming/flume.py
@@ -0,0 +1,147 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version >= "3":
+    from io import BytesIO
+else:
+    from StringIO import StringIO
+from py4j.java_gateway import Py4JJavaError
+
+from pyspark.storagelevel import StorageLevel
+from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int
+from pyspark.streaming import DStream
+
+__all__ = ['FlumeUtils', 'utf8_decoder']
+
+
+def utf8_decoder(s):
+    """ Decode the unicode as UTF-8 """
+    return s and s.decode('utf-8')
+
+
+class FlumeUtils(object):
+
+    @staticmethod
+    def createStream(ssc, hostname, port,
+                     storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
+                     enableDecompression=False,
+                     bodyDecoder=utf8_decoder):
+        """
+        Create an input stream that pulls events from Flume.
+
+        :param ssc:  StreamingContext object
+        :param hostname:  Hostname of the slave machine to which the flume data will be sent
+        :param port:  Port of the slave machine to which the flume data will be sent
+        :param storageLevel:  Storage level to use for storing the received objects
+        :param enableDecompression:  Whether the Netty server should decompress the input stream
+        :param bodyDecoder:  A function used to decode body (default is utf8_decoder)
+        :return: A DStream object
+        """
+        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
+
+        try:
+            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
+                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
+            helper = helperClass.newInstance()
+            jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
+        except Py4JJavaError as e:
+            if 'ClassNotFoundException' in str(e.java_exception):
+                FlumeUtils._printErrorMsg(ssc.sparkContext)
+            raise e
+
+        return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
+
+    @staticmethod
+    def createPollingStream(ssc, addresses,
+                            storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
+                            maxBatchSize=1000,
+                            parallelism=5,
+                            bodyDecoder=utf8_decoder):
+        """
+        Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent.
+        This stream will poll the sink for data and will pull events as they are available.
+
+        :param ssc:  StreamingContext object
+        :param addresses:  List of (host, port)s on which the Spark Sink is running.
+        :param storageLevel:  Storage level to use for storing the received objects
+        :param maxBatchSize:  The maximum number of events to be pulled from the Spark sink
+                              in a single RPC call
+        :param parallelism:  Number of concurrent requests this stream should send to the sink.
+                             Note that having a higher number of requests concurrently being pulled
+                             will result in this stream using more threads
+        :param bodyDecoder:  A function used to decode body (default is utf8_decoder)
+        :return: A DStream object
+        """
+        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
+        hosts = []
+        ports = []
+        for (host, port) in addresses:
+            hosts.append(host)
+            ports.append(port)
+
+        try:
+            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
+            helper = helperClass.newInstance()
+            jstream = helper.createPollingStream(
+                ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
+        except Py4JJavaError as e:
+            if 'ClassNotFoundException' in str(e.java_exception):
+                FlumeUtils._printErrorMsg(ssc.sparkContext)
+            raise e
+
+        return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
+
+    @staticmethod
+    def _toPythonDStream(ssc, jstream, bodyDecoder):
+        ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+        stream = DStream(jstream, ssc, ser)
+
+        def func(event):
+            headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0])
+            headers = {}
+            strSer = UTF8Deserializer()
+            for i in range(0, read_int(headersBytes)):
+                key = strSer.loads(headersBytes)
+                value = strSer.loads(headersBytes)
+                headers[key] = value
+            body = bodyDecoder(event[1])
+            return (headers, body)
+        return stream.map(func)
+
+    @staticmethod
+    def _printErrorMsg(sc):
+        print("""
+________________________________________________________________________________________________
+
+  Spark Streaming's Flume libraries not found in class path. Try one of the following.
+
+  1. Include the Flume library and its dependencies in the
+     spark-submit command as
+
+     $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ...
+
+  2. Download the JAR of the artifact from Maven Central http://search.maven.org/,
+     Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s.
+     Then, include the jar in the spark-submit command as
+
+     $ bin/spark-submit --jars  ...
+
+________________________________________________________________________________________________
+
+""" % (sc.version, sc.version))
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 10a859a532e28..33dd596335b47 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -21,6 +21,8 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.serializers import PairDeserializer, NoOpSerializer
 from pyspark.streaming import DStream
+from pyspark.streaming.dstream import TransformedDStream
+from pyspark.streaming.util import TransformFunction
 
 __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
 
@@ -122,8 +124,9 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
             raise e
 
         ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-        stream = DStream(jstream, ssc, ser)
-        return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        stream = DStream(jstream, ssc, ser) \
+            .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
 
     @staticmethod
     def createRDD(sc, kafkaParams, offsetRanges, leaders={},
@@ -161,8 +164,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
             raise e
 
         ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-        rdd = RDD(jrdd, sc, ser)
-        return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
+        return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
 
     @staticmethod
     def _printErrorMsg(sc):
@@ -200,14 +203,30 @@ def __init__(self, topic, partition, fromOffset, untilOffset):
         :param fromOffset: Inclusive starting offset.
         :param untilOffset: Exclusive ending offset.
         """
-        self._topic = topic
-        self._partition = partition
-        self._fromOffset = fromOffset
-        self._untilOffset = untilOffset
+        self.topic = topic
+        self.partition = partition
+        self.fromOffset = fromOffset
+        self.untilOffset = untilOffset
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return (self.topic == other.topic
+                    and self.partition == other.partition
+                    and self.fromOffset == other.fromOffset
+                    and self.untilOffset == other.untilOffset)
+        else:
+            return False
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __str__(self):
+        return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \
+               % (self.topic, self.partition, self.fromOffset, self.untilOffset)
 
     def _jOffsetRange(self, helper):
-        return helper.createOffsetRange(self._topic, self._partition, self._fromOffset,
-                                        self._untilOffset)
+        return helper.createOffsetRange(self.topic, self.partition, self.fromOffset,
+                                        self.untilOffset)
 
 
 class TopicAndPartition(object):
@@ -244,3 +263,87 @@ def __init__(self, host, port):
 
     def _jBroker(self, helper):
         return helper.createBroker(self._host, self._port)
+
+
+class KafkaRDD(RDD):
+    """
+    A Python wrapper of KafkaRDD, providing Kafka offset information on top of a normal RDD.
+    """
+
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
+        RDD.__init__(self, jrdd, ctx, jrdd_deserializer)
+
+    def offsetRanges(self):
+        """
+        Get the OffsetRanges of this KafkaRDD.
+        :return: A list of OffsetRange
+        """
+        try:
+            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
+            helper = helperClass.newInstance()
+            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
+        except Py4JJavaError as e:
+            if 'ClassNotFoundException' in str(e.java_exception):
+                KafkaUtils._printErrorMsg(self.ctx)
+            raise e
+
+        ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
+                  for o in joffsetRanges]
+        return ranges
+
+
+class KafkaDStream(DStream):
+    """
+    A Python wrapper of KafkaDStream
+    """
+
+    def __init__(self, jdstream, ssc, jrdd_deserializer):
+        DStream.__init__(self, jdstream, ssc, jrdd_deserializer)
+
+    def foreachRDD(self, func):
+        """
+        Apply a function to each RDD in this DStream.
+        """
+        if func.__code__.co_argcount == 1:
+            old_func = func
+            func = lambda r, rdd: old_func(rdd)
+        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \
+            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
+        api = self._ssc._jvm.PythonDStream
+        api.callForeachRDD(self._jdstream, jfunc)
+
+    def transform(self, func):
+        """
+        Return a new DStream in which each RDD is generated by applying a function
+        on each RDD of this DStream.
+
+        `func` can have one argument of `rdd`, or have two arguments of
+        (`time`, `rdd`)
+        """
+        if func.__code__.co_argcount == 1:
+            oldfunc = func
+            func = lambda t, rdd: oldfunc(rdd)
+        assert func.__code__.co_argcount == 2, "func should take one or two arguments"
+
+        return KafkaTransformedDStream(self, func)
+
+
+class KafkaTransformedDStream(TransformedDStream):
+    """
+    Kafka specific wrapper of TransformedDStream to transform on Kafka RDD.
+    """
+
+    def __init__(self, prev, func):
+        TransformedDStream.__init__(self, prev, func)
+
+    @property
+    def _jdstream(self):
+        if self._jdstream_val is not None:
+            return self._jdstream_val
+
+        jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \
+            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
+        dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
+        self._jdstream_val = dstream.asJavaDStream()
+        return self._jdstream_val
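
A sketch of how the new wrappers expose Kafka offsets to Python callbacks; it assumes a StreamingContext named `ssc`, a broker reachable at localhost:9092, and the Kafka assembly on the classpath:

    from pyspark.streaming.kafka import KafkaUtils

    stream = KafkaUtils.createDirectStream(
        ssc, ["topic1"], {"metadata.broker.list": "localhost:9092"})

    def report_offsets(time, rdd):
        # rdd arrives as a KafkaRDD here, so offsetRanges() is available
        for o in rdd.offsetRanges():
            print(o)

    stream.foreachRDD(report_offsets)
    ssc.start()
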
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 91ce681fbe169..4ecae1e4bf282 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -38,6 +38,7 @@
 from pyspark.context import SparkConf, SparkContext, RDD
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
+from pyspark.streaming.flume import FlumeUtils
 
 
 class PySparkStreamingTestCase(unittest.TestCase):
@@ -677,7 +678,220 @@ def test_kafka_rdd_with_leaders(self):
         rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
         self._validateRddResult(sendData, rdd)
 
-if __name__ == "__main__":
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_rdd_get_offsetRanges(self):
+        """Test Python direct Kafka RDD get OffsetRanges."""
+        topic = self._randomTopic()
+        sendData = {"a": 3, "b": 4, "c": 5}
+        offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+        rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
+        self.assertEqual(offsetRanges, rdd.offsetRanges())
+
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_direct_stream_foreach_get_offsetRanges(self):
+        """Test the Python direct Kafka stream foreachRDD get offsetRanges."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
+
+        offsetRanges = []
+
+        def getOffsetRanges(_, rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+
+        stream.foreachRDD(getOffsetRanges)
+        self.ssc.start()
+        self.wait_for(offsetRanges, 1)
+
+        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_direct_stream_transform_get_offsetRanges(self):
+        """Test the Python direct Kafka stream transform get offsetRanges."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
+
+        offsetRanges = []
+
+        def transformWithOffsetRanges(rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+            return rdd
+
+        stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count())
+        self.ssc.start()
+        self.wait_for(offsetRanges, 1)
+
+        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+
+class FlumeStreamTests(PySparkStreamingTestCase):
+    timeout = 20  # seconds
+    duration = 1
+
+    def setUp(self):
+        super(FlumeStreamTests, self).setUp()
+
+        utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+            .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils")
+        self._utils = utilsClz.newInstance()
+
+    def tearDown(self):
+        if self._utils is not None:
+            self._utils.close()
+            self._utils = None
+
+        super(FlumeStreamTests, self).tearDown()
+
+    def _startContext(self, n, compressed):
+        # Start the StreamingContext and also collect the result
+        dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(),
+                                          enableDecompression=compressed)
+        result = []
+
+        def get_output(_, rdd):
+            for event in rdd.collect():
+                if len(result) < n:
+                    result.append(event)
+        dstream.foreachRDD(get_output)
+        self.ssc.start()
+        return result
+
+    def _validateResult(self, input, result):
+        # Validate both the header and the body
+        header = {"test": "header"}
+        self.assertEqual(len(input), len(result))
+        for i in range(0, len(input)):
+            self.assertEqual(header, result[i][0])
+            self.assertEqual(input[i], result[i][1])
+
+    def _writeInput(self, input, compressed):
+        # Try to write input to the receiver until success or timeout
+        start_time = time.time()
+        while True:
+            try:
+                self._utils.writeInput(input, compressed)
+                break
+            except:
+                if time.time() - start_time < self.timeout:
+                    time.sleep(0.01)
+                else:
+                    raise
+
+    def test_flume_stream(self):
+        input = [str(i) for i in range(1, 101)]
+        result = self._startContext(len(input), False)
+        self._writeInput(input, False)
+        self.wait_for(result, len(input))
+        self._validateResult(input, result)
+
+    def test_compressed_flume_stream(self):
+        input = [str(i) for i in range(1, 101)]
+        result = self._startContext(len(input), True)
+        self._writeInput(input, True)
+        self.wait_for(result, len(input))
+        self._validateResult(input, result)
+
+
+class FlumePollingStreamTests(PySparkStreamingTestCase):
+    timeout = 20  # seconds
+    duration = 1
+    maxAttempts = 5
+
+    def setUp(self):
+        utilsClz = \
+            self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils")
+        self._utils = utilsClz.newInstance()
+
+    def tearDown(self):
+        if self._utils is not None:
+            self._utils.close()
+            self._utils = None
+
+    def _writeAndVerify(self, ports):
+        # Set up the streaming context and input streams
+        ssc = StreamingContext(self.sc, self.duration)
+        try:
+            addresses = [("localhost", port) for port in ports]
+            dstream = FlumeUtils.createPollingStream(
+                ssc,
+                addresses,
+                maxBatchSize=self._utils.eventsPerBatch(),
+                parallelism=5)
+            outputBuffer = []
+
+            def get_output(_, rdd):
+                for e in rdd.collect():
+                    outputBuffer.append(e)
+
+            dstream.foreachRDD(get_output)
+            ssc.start()
+            self._utils.sendDatAndEnsureAllDataHasBeenReceived()
+
+            self.wait_for(outputBuffer, self._utils.getTotalEvents())
+            outputHeaders = [event[0] for event in outputBuffer]
+            outputBodies = [event[1] for event in outputBuffer]
+            self._utils.assertOutput(outputHeaders, outputBodies)
+        finally:
+            ssc.stop(False)
+
+    def _testMultipleTimes(self, f):
+        attempt = 0
+        while True:
+            try:
+                f()
+                break
+            except:
+                attempt += 1
+                if attempt >= self.maxAttempts:
+                    raise
+                else:
+                    import traceback
+                    traceback.print_exc()
+
+    def _testFlumePolling(self):
+        try:
+            port = self._utils.startSingleSink()
+            self._writeAndVerify([port])
+            self._utils.assertChannelsAreEmpty()
+        finally:
+            self._utils.close()
+
+    def _testFlumePollingMultipleHosts(self):
+        try:
+            port = self._utils.startSingleSink()
+            self._writeAndVerify([port])
+            self._utils.assertChannelsAreEmpty()
+        finally:
+            self._utils.close()
+
+    def test_flume_polling(self):
+        self._testMultipleTimes(self._testFlumePolling)
+
+    def test_flume_polling_multiple_hosts(self):
+        self._testMultipleTimes(self._testFlumePollingMultipleHosts)
+
+
+def search_kafka_assembly_jar():
     SPARK_HOME = os.environ["SPARK_HOME"]
     kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly")
     jars = glob.glob(
@@ -692,5 +906,30 @@ def test_kafka_rdd_with_leaders(self):
         raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please "
                          "remove all but one") % kafka_assembly_dir)
     else:
-        os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0]
+        return jars[0]
+
+
+def search_flume_assembly_jar():
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly")
+    jars = glob.glob(
+        os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar"))
+    if not jars:
+        raise Exception(
+            ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) +
+            "You need to build Spark with "
+            "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or "
+            "'build/mvn package' before running this test")
+    elif len(jars) > 1:
+        raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please "
+                         "remove all but one") % flume_assembly_dir)
+    else:
+        return jars[0]
+
+if __name__ == "__main__":
+    kafka_assembly_jar = search_kafka_assembly_jar()
+    flume_assembly_jar = search_flume_assembly_jar()
+    jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar)
+
+    os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
     unittest.main()
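
The assertions above rely on the value-based equality added to OffsetRange earlier in this patch; a standalone sketch (no Spark services required):

    from pyspark.streaming.kafka import OffsetRange

    a = OffsetRange("topic1", 0, 0, 6)
    b = OffsetRange("topic1", 0, 0, 6)
    print(a == b)                               # True: compares topic, partition and offsets
    print(a != OffsetRange("topic1", 0, 0, 7))  # True
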
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index a9bfec2aab8fc..b20613b1283bd 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -37,6 +37,11 @@ def __init__(self, ctx, func, *deserializers):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
+        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+
+    def rdd_wrapper(self, func):
+        self._rdd_wrapper = func
+        return self
 
     def call(self, milliseconds, jrdds):
         try:
@@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds):
             if len(sers) < len(jrdds):
                 sers += (sers[0],) * (len(jrdds) - len(sers))
 
-            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
+            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
                     for jrdd, ser in zip(jrdds, sers)]
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 17256dfc95744..c5c0add49d02c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -885,6 +885,18 @@ def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
             for size in sizes:
                 self.assertGreater(size, 0)
 
+    def test_pipe_functions(self):
+        data = ['1', '2', '3']
+        rdd = self.sc.parallelize(data)
+        with QuietTest(self.sc):
+            self.assertEqual([], rdd.pipe('cc').collect())
+            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
+        result = rdd.pipe('cat').collect()
+        result.sort()
+        [self.assertEqual(x, y) for x, y in zip(data, result)]
+        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
+        self.assertEqual([], rdd.pipe('grep 4').collect())
+
 
 class ProfilerTests(PySparkTestCase):
 
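A brief sketch of the pipe behaviour the new test exercises: by default a failing external command yields an empty result, while checkCode=True surfaces the non-zero exit status as an error. It assumes an existing SparkContext named `sc` and a POSIX `grep` on the workers:

    rdd = sc.parallelize(["1", "2", "3"])
    print(rdd.pipe("cat").collect())              # ['1', '2', '3']
    print(rdd.pipe("grep 4").collect())           # [] -- the failure is silently ignored
    rdd.pipe("grep 4", checkCode=True).collect()  # raises: grep exits with status 1
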
diff --git a/python/run-tests.py b/python/run-tests.py
index 7d485b500ee3a..cc560779373b3 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -18,12 +18,36 @@
 #
 
 from __future__ import print_function
+import logging
 from optparse import OptionParser
 import os
 import re
 import subprocess
 import sys
+import tempfile
+from threading import Thread, Lock
 import time
+if sys.version < '3':
+    import Queue
+else:
+    import queue as Queue
+if sys.version_info >= (2, 7):
+    subprocess_check_output = subprocess.check_output
+else:
+    # SPARK-8763
+    # backported from subprocess module in Python 2.7
+    def subprocess_check_output(*popenargs, **kwargs):
+        if 'stdout' in kwargs:
+            raise ValueError('stdout argument not allowed, it will be overridden.')
+        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
+        output, unused_err = process.communicate()
+        retcode = process.poll()
+        if retcode:
+            cmd = kwargs.get("args")
+            if cmd is None:
+                cmd = popenargs[0]
+            raise subprocess.CalledProcessError(retcode, cmd, output=output)
+        return output
 
 
 # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
@@ -43,34 +67,56 @@ def print_red(text):
 
 
 LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
+FAILURE_REPORTING_LOCK = Lock()
+LOGGER = logging.getLogger()
 
 
 def run_individual_python_test(test_name, pyspark_python):
-    env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}
-    print("    Running test: %s ..." % test_name, end='')
+    env = dict(os.environ)
+    env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)})
+    LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
     start_time = time.time()
-    with open(LOG_FILE, 'a') as log_file:
-        retcode = subprocess.call(
+    try:
+        per_test_output = tempfile.TemporaryFile()
+        retcode = subprocess.Popen(
             [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
-            stderr=log_file, stdout=log_file, env=env)
+            stderr=per_test_output, stdout=per_test_output, env=env).wait()
+    except:
+        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
+        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
+        # this code is invoked from a thread other than the main thread.
+        os._exit(1)
     duration = time.time() - start_time
     # Exit on the first failure.
     if retcode != 0:
-        with open(LOG_FILE, 'r') as log_file:
-            for line in log_file:
-                if not re.match('[0-9]+', line):
-                    print(line, end='')
-        print_red("\nHad test failures in %s; see logs." % test_name)
-        exit(-1)
+        try:
+            with FAILURE_REPORTING_LOCK:
+                with open(LOG_FILE, 'ab') as log_file:
+                    per_test_output.seek(0)
+                    log_file.writelines(per_test_output)
+                per_test_output.seek(0)
+                for line in per_test_output:
+                    decoded_line = line.decode()
+                    if not re.match('[0-9]+', decoded_line):
+                        print(decoded_line, end='')
+                per_test_output.close()
+        except:
+            LOGGER.exception("Got an exception while trying to print failed test output")
+        finally:
+            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
+            # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
+            # this code is invoked from a thread other than the main thread.
+            os._exit(-1)
     else:
-        print("ok (%is)" % duration)
+        per_test_output.close()
+        LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
 
 
 def get_default_python_executables():
     python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)]
     if "python2.6" not in python_execs:
-        print("WARNING: Not testing against `python2.6` because it could not be found; falling"
-              " back to `python` instead")
+        LOGGER.warning("Not testing against `python2.6` because it could not be found; falling"
+                       " back to `python` instead")
         python_execs.insert(0, "python")
     return python_execs
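
The `os._exit()` calls in `run_individual_python_test` are deliberate: `sys.exit()` only raises `SystemExit` inside the calling thread, and the `threading` module silently swallows it, so a failure detected in a worker would not stop the run. A small standalone sketch of the difference:

```python
import os
import sys
import threading
import time

def exit_with_sys_exit():
    # SystemExit raised in a worker thread is caught by the threading module;
    # the rest of the process keeps running.
    sys.exit(1)

t = threading.Thread(target=exit_with_sys_exit)
t.start()
t.join()
print("still running after sys.exit() in a worker thread")

def exit_with_os_exit():
    # os._exit() terminates the whole process immediately, from any thread,
    # without unwinding or running cleanup handlers.
    os._exit(1)

threading.Thread(target=exit_with_os_exit).start()
time.sleep(1)
print("this line is never reached")
```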
 
@@ -88,16 +134,31 @@ def parse_opts():
         default=",".join(sorted(python_modules.keys())),
         help="A comma-separated list of Python modules to test (default: %default)"
     )
+    parser.add_option(
+        "-p", "--parallelism", type="int", default=4,
+        help="The number of suites to test in parallel (default %default)"
+    )
+    parser.add_option(
+        "--verbose", action="store_true",
+        help="Enable additional debug logging"
+    )
 
     (opts, args) = parser.parse_args()
     if args:
         parser.error("Unsupported arguments: %s" % ' '.join(args))
+    if opts.parallelism < 1:
+        parser.error("Parallelism cannot be less than 1")
     return opts
 
 
 def main():
     opts = parse_opts()
-    print("Running PySpark tests. Output is in python/%s" % LOG_FILE)
+    if (opts.verbose):
+        log_level = logging.DEBUG
+    else:
+        log_level = logging.INFO
+    logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
+    LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE)
     if os.path.exists(LOG_FILE):
         os.remove(LOG_FILE)
     python_execs = opts.python_executables.split(',')
@@ -108,24 +169,45 @@ def main():
         else:
             print("Error: unrecognized module %s" % module_name)
             sys.exit(-1)
-    print("Will test against the following Python executables: %s" % python_execs)
-    print("Will test the following Python modules: %s" % [x.name for x in modules_to_test])
+    LOGGER.info("Will test against the following Python executables: %s", python_execs)
+    LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
 
-    start_time = time.time()
+    task_queue = Queue.Queue()
     for python_exec in python_execs:
-        python_implementation = subprocess.check_output(
+        python_implementation = subprocess_check_output(
             [python_exec, "-c", "import platform; print(platform.python_implementation())"],
             universal_newlines=True).strip()
-        print("Testing with `%s`: " % python_exec, end='')
-        subprocess.call([python_exec, "--version"])
-
+        LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
+        LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output(
+            [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
         for module in modules_to_test:
             if python_implementation not in module.blacklisted_python_implementations:
-                print("Running %s tests ..." % module.name)
                 for test_goal in module.python_test_goals:
-                    run_individual_python_test(test_goal, python_exec)
+                    task_queue.put((python_exec, test_goal))
+
+    def process_queue(task_queue):
+        while True:
+            try:
+                (python_exec, test_goal) = task_queue.get_nowait()
+            except Queue.Empty:
+                break
+            try:
+                run_individual_python_test(test_goal, python_exec)
+            finally:
+                task_queue.task_done()
+
+    start_time = time.time()
+    for _ in range(opts.parallelism):
+        worker = Thread(target=process_queue, args=(task_queue,))
+        worker.daemon = True
+        worker.start()
+    try:
+        task_queue.join()
+    except (KeyboardInterrupt, SystemExit):
+        print_red("Exiting due to interrupt")
+        sys.exit(-1)
     total_duration = time.time() - start_time
-    print("Tests passed in %i seconds" % total_duration)
+    LOGGER.info("Tests passed in %i seconds", total_duration)
 
 
 if __name__ == "__main__":
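
Taken together, the new runner is a fixed-size pool of daemon threads draining a queue of (interpreter, test goal) pairs while the main thread blocks on `Queue.join()`. A self-contained sketch of that pattern, with a dummy task standing in for the real `bin/pyspark` invocation:

```python
import time
from threading import Thread

try:                          # Python 2
    import Queue
except ImportError:           # Python 3
    import queue as Queue

def run_task(name):
    # Stand-in for run_individual_python_test(test_goal, python_exec).
    time.sleep(0.1)
    print("finished %s" % name)

def process_queue(task_queue):
    # Same worker loop as in run-tests.py: drain the queue until it is empty.
    while True:
        try:
            name = task_queue.get_nowait()
        except Queue.Empty:
            break
        try:
            run_task(name)
        finally:
            task_queue.task_done()

task_queue = Queue.Queue()
for i in range(8):
    task_queue.put("task-%d" % i)

for _ in range(4):            # analogous to --parallelism=4
    worker = Thread(target=process_queue, args=(task_queue,))
    worker.daemon = True      # workers must not keep the process alive
    worker.start()

task_queue.join()             # returns once every task_done() has been called
print("all tasks finished")
```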
diff --git a/repl/pom.xml b/repl/pom.xml
index 370b2bc2fa8ed..70c9bd7c01296 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -38,11 +38,6 @@
   </properties>
 
   <dependencies>
-    <dependency>
-      <groupId>${jline.groupid}</groupId>
-      <artifactId>jline</artifactId>
-      <version>${jline.version}</version>
-    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
@@ -161,6 +156,20 @@
     </plugins>
   </build>
   <profiles>
+    <profile>
+      <id>scala-2.10</id>
+      <activation>
+        <property><name>!scala-2.11</name></property>
+      </activation>
+      <dependencies>
+        <dependency>
+          <groupId>${jline.groupid}</groupId>
+          <artifactId>jline</artifactId>
+          <version>${jline.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
+
     <profile>
       <id>scala-2.11</id>
       <activation>
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
index 6480e2d24e044..24fbbc12c08da 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
@@ -39,6 +39,8 @@ class SparkCommandLine(args: List[String], override val settings: Settings)
   }
 
   def this(args: List[String]) {
+    // scalastyle:off println
     this(args, str => Console.println("Error: " + str))
+    // scalastyle:on println
   }
 }
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 2b235525250c2..8f7f9074d3f03 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -1101,7 +1101,9 @@ object SparkILoop extends Logging {
             val s = super.readLine()
             // helping out by printing the line being interpreted.
             if (s != null)
+              // scalastyle:off println
               output.println(s)
+              // scalastyle:on println
             s
           }
         }
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
index 05faef8786d2c..bd3314d94eed6 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -80,11 +80,13 @@ private[repl] trait SparkILoopInit {
     if (!initIsComplete)
       withLock { while (!initIsComplete) initLoopCondition.await() }
     if (initError != null) {
+      // scalastyle:off println
       println("""
         |Failed to initialize the REPL due to an unexpected error.
         |This is a bug, please, report it along with the error diagnostics printed below.
         |%s.""".stripMargin.format(initError)
       )
+      // scalastyle:on println
       false
     } else true
   }
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 35fb625645022..8791618bd355e 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -1761,7 +1761,9 @@ object SparkIMain {
         if (intp.totalSilence) ()
         else super.printMessage(msg)
       }
+      // scalastyle:off println
       else Console.println(msg)
+      // scalastyle:on println
     }
   }
 }
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
index f4f4b626988e9..eed4a379afa60 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.repl
 
+import java.io.File
+
+import scala.tools.nsc.Settings
+
 import org.apache.spark.util.Utils
 import org.apache.spark._
 import org.apache.spark.sql.SQLContext
 
-import scala.tools.nsc.Settings
-import scala.tools.nsc.interpreter.SparkILoop
-
 object Main extends Logging {
 
   val conf = new SparkConf()
@@ -32,7 +33,8 @@ object Main extends Logging {
   val outputDir = Utils.createTempDir(rootDir)
   val s = new Settings()
   s.processArguments(List("-Yrepl-class-based",
-    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true)
+    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
+    "-classpath", getAddedJars.mkString(File.pathSeparator)), true)
   val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
   var sparkContext: SparkContext = _
   var sqlContext: SQLContext = _
@@ -48,7 +50,6 @@ object Main extends Logging {
     Option(sparkContext).map(_.stop)
   }
 
-
   def getAddedJars: Array[String] = {
     val envJars = sys.env.get("ADD_JARS")
     if (envJars.isDefined) {
@@ -84,10 +85,9 @@ object Main extends Logging {
     val loader = Utils.getContextOrSparkClassLoader
     try {
       sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext])
-        .newInstance(sparkContext).asInstanceOf[SQLContext] 
+        .newInstance(sparkContext).asInstanceOf[SQLContext]
       logInfo("Created sql context (with Hive support)..")
-    }
-    catch {
+    } catch {
       case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError =>
         sqlContext = new SQLContext(sparkContext)
         logInfo("Created sql context..")
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
deleted file mode 100644
index 8e519fa67f649..0000000000000
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-/* NSC -- new Scala compiler
- * Copyright 2005-2013 LAMP/EPFL
- * @author  Paul Phillips
- */
-
-package scala.tools.nsc
-package interpreter
-
-import scala.tools.nsc.ast.parser.Tokens.EOF
-
-trait SparkExprTyper {
-  val repl: SparkIMain
-
-  import repl._
-  import global.{ reporter => _, Import => _, _ }
-  import naming.freshInternalVarName
-
-  def symbolOfLine(code: String): Symbol = {
-    def asExpr(): Symbol = {
-      val name  = freshInternalVarName()
-      // Typing it with a lazy val would give us the right type, but runs
-      // into compiler bugs with things like existentials, so we compile it
-      // behind a def and strip the NullaryMethodType which wraps the expr.
-      val line = "def " + name + " = " + code
-
-      interpretSynthetic(line) match {
-        case IR.Success =>
-          val sym0 = symbolOfTerm(name)
-          // drop NullaryMethodType
-          sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
-        case _          => NoSymbol
-      }
-    }
-    def asDefn(): Symbol = {
-      val old = repl.definedSymbolList.toSet
-
-      interpretSynthetic(code) match {
-        case IR.Success =>
-          repl.definedSymbolList filterNot old match {
-            case Nil        => NoSymbol
-            case sym :: Nil => sym
-            case syms       => NoSymbol.newOverloaded(NoPrefix, syms)
-          }
-        case _ => NoSymbol
-      }
-    }
-    def asError(): Symbol = {
-      interpretSynthetic(code)
-      NoSymbol
-    }
-    beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
-  }
-
-  private var typeOfExpressionDepth = 0
-  def typeOfExpression(expr: String, silent: Boolean = true): Type = {
-    if (typeOfExpressionDepth > 2) {
-      repldbg("Terminating typeOfExpression recursion for expression: " + expr)
-      return NoType
-    }
-    typeOfExpressionDepth += 1
-    // Don't presently have a good way to suppress undesirable success output
-    // while letting errors through, so it is first trying it silently: if there
-    // is an error, and errors are desired, then it re-evaluates non-silently
-    // to induce the error message.
-    try beSilentDuring(symbolOfLine(expr).tpe) match {
-      case NoType if !silent => symbolOfLine(expr).tpe // generate error
-      case tpe               => tpe
-    }
-    finally typeOfExpressionDepth -= 1
-  }
-
-  // This only works for proper types.
-  def typeOfTypeString(typeString: String): Type = {
-    def asProperType(): Option[Type] = {
-      val name = freshInternalVarName()
-      val line = "def %s: %s = ???" format (name, typeString)
-      interpretSynthetic(line) match {
-        case IR.Success =>
-          val sym0 = symbolOfTerm(name)
-          Some(sym0.asMethod.returnType)
-        case _          => None
-      }
-    }
-    beSilentDuring(asProperType()) getOrElse NoType
-  }
-}
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 7a5e94da5cbf3..bf609ff0f65fc 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -1,88 +1,64 @@
-/* NSC -- new Scala compiler
- * Copyright 2005-2013 LAMP/EPFL
- * @author Alexander Spoon
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
-package scala
-package tools.nsc
-package interpreter
+package org.apache.spark.repl
 
-import scala.language.{ implicitConversions, existentials }
-import scala.annotation.tailrec
-import Predef.{ println => _, _ }
-import interpreter.session._
-import StdReplTags._
-import scala.reflect.api.{Mirror, Universe, TypeCreator}
-import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName }
-import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
-import scala.reflect.{ClassTag, classTag}
-import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader }
-import ScalaClassLoader._
-import scala.reflect.io.{ File, Directory }
-import scala.tools.util._
-import scala.collection.generic.Clearable
-import scala.concurrent.{ ExecutionContext, Await, Future, future }
-import ExecutionContext.Implicits._
-import java.io.{ BufferedReader, FileReader }
+import java.io.{BufferedReader, FileReader}
 
-/** The Scala interactive shell.  It provides a read-eval-print loop
-  *  around the Interpreter class.
-  *  After instantiation, clients should call the main() method.
-  *
-  *  If no in0 is specified, then input will come from the console, and
-  *  the class will attempt to provide input editing feature such as
-  *  input history.
-  *
-  *  @author Moez A. Abdel-Gawad
-  *  @author  Lex Spoon
-  *  @version 1.2
-  */
-class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
-  extends AnyRef
-  with LoopCommands
-{
-  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
-  def this() = this(None, new JPrintWriter(Console.out, true))
-//
-//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
-//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i
-
-  var in: InteractiveReader = _   // the input stream from which commands come
-  var settings: Settings = _
-  var intp: SparkIMain = _
+import Predef.{println => _, _}
+import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName}
 
-  var globalFuture: Future[Boolean] = _
+import scala.tools.nsc.interpreter.{JPrintWriter, ILoop}
+import scala.tools.nsc.Settings
+import scala.tools.nsc.util.stringFromStream
 
-  protected def asyncMessage(msg: String) {
-    if (isReplInfo || isReplPower)
-      echoAndRefresh(msg)
-  }
+/**
+ *  A Spark-specific interactive shell.
+ */
+class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
+    extends ILoop(in0, out) {
+  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
+  def this() = this(None, new JPrintWriter(Console.out, true))
 
   def initializeSpark() {
     intp.beQuietDuring {
-      command( """
+      processLine("""
          @transient val sc = {
            val _sc = org.apache.spark.repl.Main.createSparkContext()
            println("Spark context available as sc.")
            _sc
          }
         """)
-      command( """
+      processLine("""
          @transient val sqlContext = {
            val _sqlContext = org.apache.spark.repl.Main.createSQLContext()
            println("SQL context available as sqlContext.")
            _sqlContext
          }
         """)
-      command("import org.apache.spark.SparkContext._")
-      command("import sqlContext.implicits._")
-      command("import sqlContext.sql")
-      command("import org.apache.spark.sql.functions._")
+      processLine("import org.apache.spark.SparkContext._")
+      processLine("import sqlContext.implicits._")
+      processLine("import sqlContext.sql")
+      processLine("import org.apache.spark.sql.functions._")
     }
   }
 
   /** Print a welcome message */
-  def printWelcome() {
+  override def printWelcome() {
     import org.apache.spark.SPARK_VERSION
     echo("""Welcome to
       ____              __
@@ -98,875 +74,42 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
     echo("Type :help for more information.")
   }
 
-  override def echoCommandMessage(msg: String) {
-    intp.reporter printUntruncatedMessage msg
-  }
-
-  // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
-  def history = in.history
-
-  // classpath entries added via :cp
-  var addedClasspath: String = ""
-
-  /** A reverse list of commands to replay if the user requests a :replay */
-  var replayCommandStack: List[String] = Nil
-
-  /** A list of commands to replay if the user requests a :replay */
-  def replayCommands = replayCommandStack.reverse
-
-  /** Record a command for replay should the user request a :replay */
-  def addReplay(cmd: String) = replayCommandStack ::= cmd
-
-  def savingReplayStack[T](body: => T): T = {
-    val saved = replayCommandStack
-    try body
-    finally replayCommandStack = saved
-  }
-  def savingReader[T](body: => T): T = {
-    val saved = in
-    try body
-    finally in = saved
-  }
-
-  /** Close the interpreter and set the var to null. */
-  def closeInterpreter() {
-    if (intp ne null) {
-      intp.close()
-      intp = null
-    }
-  }
-
-  class SparkILoopInterpreter extends SparkIMain(settings, out) {
-    outer =>
-
-    override lazy val formatting = new Formatting {
-      def prompt = SparkILoop.this.prompt
-    }
-    override protected def parentClassLoader =
-      settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader )
-  }
-
-  /** Create a new interpreter. */
-  def createInterpreter() {
-    if (addedClasspath != "")
-      settings.classpath append addedClasspath
-
-    intp = new SparkILoopInterpreter
-  }
-
-  /** print a friendly help message */
-  def helpCommand(line: String): Result = {
-    if (line == "") helpSummary()
-    else uniqueCommand(line) match {
-      case Some(lc) => echo("\n" + lc.help)
-      case _        => ambiguousError(line)
-    }
-  }
-  private def helpSummary() = {
-    val usageWidth  = commands map (_.usageMsg.length) max
-    val formatStr   = "%-" + usageWidth + "s %s"
-
-    echo("All commands can be abbreviated, e.g. :he instead of :help.")
-
-    commands foreach { cmd =>
-      echo(formatStr.format(cmd.usageMsg, cmd.help))
-    }
-  }
-  private def ambiguousError(cmd: String): Result = {
-    matchingCommands(cmd) match {
-      case Nil  => echo(cmd + ": no such command.  Type :help for help.")
-      case xs   => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
-    }
-    Result(keepRunning = true, None)
-  }
-  private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
-  private def uniqueCommand(cmd: String): Option[LoopCommand] = {
-    // this lets us add commands willy-nilly and only requires enough command to disambiguate
-    matchingCommands(cmd) match {
-      case List(x)  => Some(x)
-      // exact match OK even if otherwise appears ambiguous
-      case xs       => xs find (_.name == cmd)
-    }
-  }
-
-  /** Show the history */
-  lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
-    override def usage = "[num]"
-    def defaultLines = 20
-
-    def apply(line: String): Result = {
-      if (history eq NoHistory)
-        return "No history available."
-
-      val xs      = words(line)
-      val current = history.index
-      val count   = try xs.head.toInt catch { case _: Exception => defaultLines }
-      val lines   = history.asStrings takeRight count
-      val offset  = current - lines.size + 1
-
-      for ((line, index) <- lines.zipWithIndex)
-        echo("%3d  %s".format(index + offset, line))
-    }
-  }
-
-  // When you know you are most likely breaking into the middle
-  // of a line being typed.  This softens the blow.
-  protected def echoAndRefresh(msg: String) = {
-    echo("\n" + msg)
-    in.redrawLine()
-  }
-  protected def echo(msg: String) = {
-    out println msg
-    out.flush()
-  }
-
-  /** Search the history */
-  def searchHistory(_cmdline: String) {
-    val cmdline = _cmdline.toLowerCase
-    val offset  = history.index - history.size + 1
-
-    for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
-      echo("%d %s".format(index + offset, line))
-  }
-
-  private val currentPrompt = Properties.shellPromptString
-
-  /** Prompt to print when awaiting input */
-  def prompt = currentPrompt
-
   import LoopCommand.{ cmd, nullary }
 
-  /** Standard commands **/
-  lazy val standardCommands = List(
-    cmd("cp", "", "add a jar or directory to the classpath", addClasspath),
-    cmd("edit", "|", "edit history", editCommand),
-    cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
-    historyCommand,
-    cmd("h?", "", "search the history", searchHistory),
-    cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
-    //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand),
-    cmd("javap", "", "disassemble a file or class name", javapCommand),
-    cmd("line", "|", "place line(s) at the end of history", lineCommand),
-    cmd("load", "", "interpret lines in a file", loadCommand),
-    cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand),
-    // nullary("power", "enable power user mode", powerCmd),
-    nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)),
-    nullary("replay", "reset execution and replay all previous commands", replay),
-    nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
-    cmd("save", "", "save replayable session to a file", saveCommand),
-    shCommand,
-    cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings),
-    nullary("silent", "disable/enable automatic printing of results", verbosity),
-//    cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand),
-//    cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand),
-    nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
-  )
-
-  /** Power user commands */
-//  lazy val powerCommands: List[LoopCommand] = List(
-//    cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
-//  )
-
-  private def importsCommand(line: String): Result = {
-    val tokens    = words(line)
-    val handlers  = intp.languageWildcardHandlers ++ intp.importHandlers
-
-    handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach {
-      case (handler, idx) =>
-        val (types, terms) = handler.importedSymbols partition (_.name.isTypeName)
-        val imps           = handler.implicitSymbols
-        val found          = tokens filter (handler importsSymbolNamed _)
-        val typeMsg        = if (types.isEmpty) "" else types.size + " types"
-        val termMsg        = if (terms.isEmpty) "" else terms.size + " terms"
-        val implicitMsg    = if (imps.isEmpty) "" else imps.size + " are implicit"
-        val foundMsg       = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
-        val statsMsg       = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
-
-        intp.reporter.printMessage("%2d) %-30s %s%s".format(
-          idx + 1,
-          handler.importString,
-          statsMsg,
-          foundMsg
-        ))
-    }
-  }
-
-  private def findToolsJar() = PathResolver.SupplementalLocations.platformTools
+  private val blockedCommands = Set("implicits", "javap", "power", "type", "kind")
 
-  private def addToolsJarToLoader() = {
-    val cl = findToolsJar() match {
-      case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
-      case _           => intp.classLoader
-    }
-    if (Javap.isAvailable(cl)) {
-      repldbg(":javap available.")
-      cl
-    }
-    else {
-      repldbg(":javap unavailable: no tools.jar at " + jdkHome)
-      intp.classLoader
-    }
-  }
-//
-//  protected def newJavap() =
-//    JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp))
-//
-//  private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
-
-  // Still todo: modules.
-//  private def typeCommand(line0: String): Result = {
-//    line0.trim match {
-//      case "" => ":type [-v] "
-//      case s  => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
-//    }
-//  }
-
-//  private def kindCommand(expr: String): Result = {
-//    expr.trim match {
-//      case "" => ":kind [-v] "
-//      case s  => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
-//    }
-//  }
-
-  private def warningsCommand(): Result = {
-    if (intp.lastWarnings.isEmpty)
-      "Can't find any cached warnings."
-    else
-      intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
-  }
-
-  private def changeSettings(args: String): Result = {
-    def showSettings() = {
-      for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString)
-    }
-    def updateSettings() = {
-      // put aside +flag options
-      val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+"))
-      val tmps = new Settings
-      val (ok, leftover) = tmps.processArguments(rest, processAll = true)
-      if (!ok) echo("Bad settings request.")
-      else if (leftover.nonEmpty) echo("Unprocessed settings.")
-      else {
-        // boolean flags set-by-user on tmp copy should be off, not on
-        val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting])
-        val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg))
-        // update non-flags
-        settings.processArguments(nonbools, processAll = true)
-        // also snag multi-value options for clearing, e.g. -Ylog: and -language:
-        for {
-          s <- settings.userSetSettings
-          if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting]
-          if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init))
-        } s match {
-          case c: Clearable => c.clear()
-          case _ =>
-        }
-        def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = {
-          for (b <- bs)
-            settings.lookupSetting(name(b)) match {
-              case Some(s) =>
-                if (s.isInstanceOf[Settings#BooleanSetting]) setter(s)
-                else echo(s"Not a boolean flag: $b")
-              case _ =>
-                echo(s"Not an option: $b")
-            }
-        }
-        update(minuses, identity, _.tryToSetFromPropertyValue("false"))  // turn off
-        update(pluses, "-" + _.drop(1), _.tryToSet(Nil))                 // turn on
-      }
-    }
-    if (args.isEmpty) showSettings() else updateSettings()
-  }
-
-  private def javapCommand(line: String): Result = {
-//    if (javap == null)
-//      ":javap unavailable, no tools.jar at %s.  Set JDK_HOME.".format(jdkHome)
-//    else if (line == "")
-//      ":javap [-lcsvp] [path1 path2 ...]"
-//    else
-//      javap(words(line)) foreach { res =>
-//        if (res.isError) return "Failed: " + res.value
-//        else res.show()
-//      }
-  }
-
-  private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent"
-
-  private def phaseCommand(name: String): Result = {
-//    val phased: Phased = power.phased
-//    import phased.NoPhaseName
-//
-//    if (name == "clear") {
-//      phased.set(NoPhaseName)
-//      intp.clearExecutionWrapper()
-//      "Cleared active phase."
-//    }
-//    else if (name == "") phased.get match {
-//      case NoPhaseName => "Usage: :phase <expr> (e.g. typer, erasure.next, erasure+3)"
-//      case ph          => "Active phase is '%s'.  (To clear, :phase clear)".format(phased.get)
-//    }
-//    else {
-//      val what = phased.parse(name)
-//      if (what.isEmpty || !phased.set(what))
-//        "'" + name + "' does not appear to represent a valid phase."
-//      else {
-//        intp.setExecutionWrapper(pathToPhaseWrapper)
-//        val activeMessage =
-//          if (what.toString.length == name.length) "" + what
-//          else "%s (%s)".format(what, name)
-//
-//        "Active phase is now: " + activeMessage
-//      }
-//    }
-  }
+  /** Standard commands **/
+  lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] =
+    standardCommands.filter(cmd => !blockedCommands(cmd.name))
 
   /** Available commands */
-  def commands: List[LoopCommand] = standardCommands ++ (
-    // if (isReplPower)
-    //  powerCommands
-    // else
-      Nil
-    )
-
-  val replayQuestionMessage =
-    """|That entry seems to have slain the compiler.  Shall I replay
-      |your session? I can re-run each line except the last one.
-      |[y/n]
-    """.trim.stripMargin
-
-  private val crashRecovery: PartialFunction[Throwable, Boolean] = {
-    case ex: Throwable =>
-      val (err, explain) = (
-        if (intp.isInitializeComplete)
-          (intp.global.throwableAsString(ex), "")
-        else
-          (ex.getMessage, "The compiler did not initialize.\n")
-        )
-      echo(err)
-
-      ex match {
-        case _: NoSuchMethodError | _: NoClassDefFoundError =>
-          echo("\nUnrecoverable error.")
-          throw ex
-        case _  =>
-          def fn(): Boolean =
-            try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
-            catch { case _: RuntimeException => false }
-
-          if (fn()) replay()
-          else echo("\nAbandoning crashed session.")
-      }
-      true
-  }
-
-  // return false if repl should exit
-  def processLine(line: String): Boolean = {
-    import scala.concurrent.duration._
-    Await.ready(globalFuture, 60.seconds)
-
-    (line ne null) && (command(line) match {
-      case Result(false, _)      => false
-      case Result(_, Some(line)) => addReplay(line) ; true
-      case _                     => true
-    })
-  }
-
-  private def readOneLine() = {
-    out.flush()
-    in readLine prompt
-  }
-
-  /** The main read-eval-print loop for the repl.  It calls
-    *  command() for each line of input, and stops when
-    *  command() returns false.
-    */
-  @tailrec final def loop() {
-    if ( try processLine(readOneLine()) catch crashRecovery )
-      loop()
-  }
-
-  /** interpret all lines from a specified file */
-  def interpretAllFrom(file: File) {
-    savingReader {
-      savingReplayStack {
-        file applyReader { reader =>
-          in = SimpleReader(reader, out, interactive = false)
-          echo("Loading " + file + "...")
-          loop()
-        }
-      }
-    }
-  }
-
-  /** create a new interpreter and replay the given commands */
-  def replay() {
-    reset()
-    if (replayCommandStack.isEmpty)
-      echo("Nothing to replay.")
-    else for (cmd <- replayCommands) {
-      echo("Replaying: " + cmd)  // flush because maybe cmd will have its own output
-      command(cmd)
-      echo("")
-    }
-  }
-  def resetCommand() {
-    echo("Resetting interpreter state.")
-    if (replayCommandStack.nonEmpty) {
-      echo("Forgetting this session history:\n")
-      replayCommands foreach echo
-      echo("")
-      replayCommandStack = Nil
-    }
-    if (intp.namedDefinedTerms.nonEmpty)
-      echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
-    if (intp.definedTypes.nonEmpty)
-      echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
-
-    reset()
-  }
-  def reset() {
-    intp.reset()
-    unleashAndSetPhase()
-  }
-
-  def lineCommand(what: String): Result = editCommand(what, None)
-
-  // :edit id or :edit line
-  def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR"))
-
-  def editCommand(what: String, editor: Option[String]): Result = {
-    def diagnose(code: String) = {
-      echo("The edited code is incomplete!\n")
-      val errless = intp compileSources new BatchSourceFile("<pastie>", s"object pastel {\n$code\n}")
-      if (errless) echo("The compiler reports no errors.")
-    }
-    def historicize(text: String) = history match {
-      case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true
-      case _ => false
-    }
-    def edit(text: String): Result = editor match {
-      case Some(ed) =>
-        val tmp = File.makeTemp()
-        tmp.writeAll(text)
-        try {
-          val pr = new ProcessResult(s"$ed ${tmp.path}")
-          pr.exitCode match {
-            case 0 =>
-              tmp.safeSlurp() match {
-                case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.")
-                case Some(edited) =>
-                  echo(edited.lines map ("+" + _) mkString "\n")
-                  val res = intp interpret edited
-                  if (res == IR.Incomplete) diagnose(edited)
-                  else {
-                    historicize(edited)
-                    Result(lineToRecord = Some(edited), keepRunning = true)
-                  }
-                case None => echo("Can't read edited text. Did you delete it?")
-              }
-            case x => echo(s"Error exit from $ed ($x), ignoring")
-          }
-        } finally {
-          tmp.delete()
-        }
-      case None =>
-        if (historicize(text)) echo("Placing text in recent history.")
-        else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text")
-    }
-
-    // if what is a number, use it as a line number or range in history
-    def isNum = what forall (c => c.isDigit || c == '-' || c == '+')
-    // except that "-" means last value
-    def isLast = (what == "-")
-    if (isLast || !isNum) {
-      val name = if (isLast) intp.mostRecentVar else what
-      val sym = intp.symbolOfIdent(name)
-      intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match {
-        case Some(req) => edit(req.line)
-        case None      => echo(s"No symbol in scope: $what")
-      }
-    } else try {
-      val s = what
-      // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)
-      val (start, len) =
-        if ((s indexOf '+') > 0) {
-          val (a,b) = s splitAt (s indexOf '+')
-          (a.toInt, b.drop(1).toInt)
-        } else {
-          (s indexOf '-') match {
-            case -1 => (s.toInt, 1)
-            case 0  => val n = s.drop(1).toInt ; (history.index - n, n)
-            case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n)
-            case i  => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n)
-          }
-        }
-      import scala.collection.JavaConverters._
-      val index = (start - 1) max 0
-      val text = history match {
-        case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n"
-        case _ => history.asStrings.slice(index, index + len) mkString "\n"
-      }
-      edit(text)
-    } catch {
-      case _: NumberFormatException => echo(s"Bad range '$what'")
-        echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)")
-    }
-  }
-
-  /** fork a shell and run a command */
-  lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
-    override def usage = ""
-    def apply(line: String): Result = line match {
-      case ""   => showUsage()
-      case _    =>
-        val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})"
-        intp interpret toRun
-        ()
-    }
-  }
-
-  def withFile[A](filename: String)(action: File => A): Option[A] = {
-    val res = Some(File(filename)) filter (_.exists) map action
-    if (res.isEmpty) echo("That file does not exist")  // courtesy side-effect
-    res
-  }
-
-  def loadCommand(arg: String) = {
-    var shouldReplay: Option[String] = None
-    withFile(arg)(f => {
-      interpretAllFrom(f)
-      shouldReplay = Some(":load " + arg)
-    })
-    Result(keepRunning = true, shouldReplay)
-  }
-
-  def saveCommand(filename: String): Result = (
-    if (filename.isEmpty) echo("File name is required.")
-    else if (replayCommandStack.isEmpty) echo("No replay commands in session")
-    else File(filename).printlnAll(replayCommands: _*)
-    )
-
-  def addClasspath(arg: String): Unit = {
-    val f = File(arg).normalize
-    if (f.exists) {
-      addedClasspath = ClassPath.join(addedClasspath, f.path)
-      val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
-      echo("Added '%s'.  Your new classpath is:\n\"%s\"".format(f.path, totalClasspath))
-      replay()
-    }
-    else echo("The path '" + f + "' doesn't seem to exist.")
-  }
-
-  def powerCmd(): Result = {
-    if (isReplPower) "Already in power mode."
-    else enablePowerMode(isDuringInit = false)
-  }
-  def enablePowerMode(isDuringInit: Boolean) = {
-    replProps.power setValue true
-    unleashAndSetPhase()
-    // asyncEcho(isDuringInit, power.banner)
-  }
-  private def unleashAndSetPhase() {
-    if (isReplPower) {
-    //  power.unleash()
-      // Set the phase to "typer"
-      // intp beSilentDuring phaseCommand("typer")
-    }
-  }
-
-  def asyncEcho(async: Boolean, msg: => String) {
-    if (async) asyncMessage(msg)
-    else echo(msg)
-  }
-
-  def verbosity() = {
-    val old = intp.printResults
-    intp.printResults = !old
-    echo("Switched " + (if (old) "off" else "on") + " result printing.")
-  }
-
-  /** Run one command submitted by the user.  Two values are returned:
-    * (1) whether to keep running, (2) the line to record for replay,
-    * if any. */
-  def command(line: String): Result = {
-    if (line startsWith ":") {
-      val cmd = line.tail takeWhile (x => !x.isWhitespace)
-      uniqueCommand(cmd) match {
-        case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace))
-        case _        => ambiguousError(cmd)
-      }
-    }
-    else if (intp.global == null) Result(keepRunning = false, None)  // Notice failure to create compiler
-    else Result(keepRunning = true, interpretStartingWith(line))
-  }
-
-  private def readWhile(cond: String => Boolean) = {
-    Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
-  }
-
-  def pasteCommand(arg: String): Result = {
-    var shouldReplay: Option[String] = None
-    def result = Result(keepRunning = true, shouldReplay)
-    val (raw, file) =
-      if (arg.isEmpty) (false, None)
-      else {
-        val r = """(-raw)?(\s+)?([^\-]\S*)?""".r
-        arg match {
-          case r(flag, sep, name) =>
-            if (flag != null && name != null && sep == null)
-              echo(s"""I assume you mean "$flag $name"?""")
-            (flag != null, Option(name))
-          case _ =>
-            echo("usage: :paste -raw file")
-            return result
-        }
-      }
-    val code = file match {
-      case Some(name) =>
-        withFile(name)(f => {
-          shouldReplay = Some(s":paste $arg")
-          val s = f.slurp.trim
-          if (s.isEmpty) echo(s"File contains no code: $f")
-          else echo(s"Pasting file $f...")
-          s
-        }) getOrElse ""
-      case None =>
-        echo("// Entering paste mode (ctrl-D to finish)\n")
-        val text = (readWhile(_ => true) mkString "\n").trim
-        if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n")
-        else echo("\n// Exiting paste mode, now interpreting.\n")
-        text
-    }
-    def interpretCode() = {
-      val res = intp interpret code
-      // if input is incomplete, let the compiler try to say why
-      if (res == IR.Incomplete) {
-        echo("The pasted code is incomplete!\n")
-        // Remembrance of Things Pasted in an object
-        val errless = intp compileSources new BatchSourceFile("<pastie>", s"object pastel {\n$code\n}")
-        if (errless) echo("...but compilation found no error? Good luck with that.")
-      }
-    }
-    def compileCode() = {
-      val errless = intp compileSources new BatchSourceFile("", code)
-      if (!errless) echo("There were compilation errors!")
-    }
-    if (code.nonEmpty) {
-      if (raw) compileCode() else interpretCode()
-    }
-    result
-  }
-
-  private object paste extends Pasted {
-    val ContinueString = "     | "
-    val PromptString   = "scala> "
-
-    def interpret(line: String): Unit = {
-      echo(line.trim)
-      intp interpret line
-      echo("")
-    }
-
-    def transcript(start: String) = {
-      echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
-      apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
-    }
-  }
-  import paste.{ ContinueString, PromptString }
-
-  /** Interpret expressions starting with the first line.
-    * Read lines until a complete compilation unit is available
-    * or until a syntax error has been seen.  If a full unit is
-    * read, go ahead and interpret it.  Return the full string
-    * to be recorded for replay, if any.
-    */
-  def interpretStartingWith(code: String): Option[String] = {
-    // signal completion non-completion input has been received
-    in.completion.resetVerbosity()
-
-    def reallyInterpret = {
-      val reallyResult = intp.interpret(code)
-      (reallyResult, reallyResult match {
-        case IR.Error       => None
-        case IR.Success     => Some(code)
-        case IR.Incomplete  =>
-          if (in.interactive && code.endsWith("\n\n")) {
-            echo("You typed two blank lines.  Starting a new command.")
-            None
-          }
-          else in.readLine(ContinueString) match {
-            case null =>
-              // we know compilation is going to fail since we're at EOF and the
-              // parser thinks the input is still incomplete, but since this is
-              // a file being read non-interactively we want to fail.  So we send
-              // it straight to the compiler for the nice error message.
-              intp.compileString(code)
-              None
-
-            case line => interpretStartingWith(code + "\n" + line)
-          }
-      })
-    }
-
-    /** Here we place ourselves between the user and the interpreter and examine
-      *  the input they are ostensibly submitting.  We intervene in several cases:
-      *
-      *  1) If the line starts with "scala> " it is assumed to be an interpreter paste.
-      *  2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
-      *     on the previous result.
-      *  3) If the Completion object's execute returns Some(_), we inject that value
-      *     and avoid the interpreter, as it's likely not valid scala code.
-      */
-    if (code == "") None
-    else if (!paste.running && code.trim.startsWith(PromptString)) {
-      paste.transcript(code)
-      None
-    }
-    else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
-      interpretStartingWith(intp.mostRecentVar + code)
-    }
-    else if (code.trim startsWith "//") {
-      // line comment, do nothing
-      None
-    }
-    else
-      reallyInterpret._2
-  }
-
-  // runs :load `file` on any files passed via -i
-  def loadFiles(settings: Settings) = settings match {
-    case settings: GenericRunnerSettings =>
-      for (filename <- settings.loadfiles.value) {
-        val cmd = ":load " + filename
-        command(cmd)
-        addReplay(cmd)
-        echo("")
-      }
-    case _ =>
-  }
-
-  /** Tries to create a JLineReader, falling back to SimpleReader:
-    *  unless settings or properties are such that it should start
-    *  with SimpleReader.
-    */
-  def chooseReader(settings: Settings): InteractiveReader = {
-    if (settings.Xnojline || Properties.isEmacsShell)
-      SimpleReader()
-    else try new JLineReader(
-      if (settings.noCompletion) NoCompletion
-      else new SparkJLineCompletion(intp)
-    )
-    catch {
-      case ex @ (_: Exception | _: NoClassDefFoundError) =>
-        echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.")
-        SimpleReader()
-    }
-  }
-  protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
-    u.TypeTag[T](
-      m,
-      new TypeCreator {
-        def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type =
-          m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
-      })
-
-  private def loopPostInit() {
-    // Bind intp somewhere out of the regular namespace where
-    // we can get at it in generated code.
-    intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain]))
-    // Auto-run code via some setting.
-    ( replProps.replAutorunCode.option
-      flatMap (f => io.File(f).safeSlurp())
-      foreach (intp quietRun _)
-      )
-    // classloader and power mode setup
-    intp.setContextClassLoader()
-    if (isReplPower) {
-     // replProps.power setValue true
-     // unleashAndSetPhase()
-     // asyncMessage(power.banner)
-    }
-    // SI-7418 Now, and only now, can we enable TAB completion.
-    in match {
-      case x: JLineReader => x.consoleReader.postInit
-      case _              =>
-    }
-  }
-  def process(settings: Settings): Boolean = savingContextLoader {
-    this.settings = settings
-    createInterpreter()
-
-    // sets in to some kind of reader depending on environmental cues
-    in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true))
-    globalFuture = future {
-      intp.initializeSynchronous()
-      loopPostInit()
-      !intp.reporter.hasErrors
-    }
-    import scala.concurrent.duration._
-    Await.ready(globalFuture, 10 seconds)
-    printWelcome()
+  override def commands: List[LoopCommand] = sparkStandardCommands
+
+  /** 
+   * We override `loadFiles` because we need to initialize Spark *before* the REPL
+   * sees any files, so that the Spark context is visible in those files. This is a bit of a
+   * hack, but there isn't another hook available to us at this point.
+   */
+  override def loadFiles(settings: Settings): Unit = {
     initializeSpark()
-    loadFiles(settings)
-
-    try loop()
-    catch AbstractOrMissingHandler()
-    finally closeInterpreter()
-
-    true
+    super.loadFiles(settings)
   }
-
-  @deprecated("Use `process` instead", "2.9.0")
-  def main(settings: Settings): Unit = process(settings) //used by sbt
 }
 
 object SparkILoop {
-  implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
 
-  // Designed primarily for use by test code: take a String with a
-  // bunch of code, and prints out a transcript of what it would look
-  // like if you'd just typed it into the repl.
-  def runForTranscript(code: String, settings: Settings): String = {
-    import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
-
-    stringFromStream { ostream =>
-      Console.withOut(ostream) {
-        val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
-          override def write(str: String) = {
-            // completely skip continuation lines
-            if (str forall (ch => ch.isWhitespace || ch == '|')) ()
-            else super.write(str)
-          }
-        }
-        val input = new BufferedReader(new StringReader(code.trim + "\n")) {
-          override def readLine(): String = {
-            val s = super.readLine()
-            // helping out by printing the line being interpreted.
-            if (s != null)
-              output.println(s)
-            s
-          }
-        }
-        val repl = new SparkILoop(input, output)
-        if (settings.classpath.isDefault)
-          settings.classpath.value = sys.props("java.class.path")
-
-        repl process settings
-      }
-    }
-  }
-
-  /** Creates an interpreter loop with default settings and feeds
-    *  the given code to it as input.
-    */
+  /** 
+   * Creates an interpreter loop with default settings and feeds
+   * the given code to it as input.
+   */
   def run(code: String, sets: Settings = new Settings): String = {
     import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
 
     stringFromStream { ostream =>
       Console.withOut(ostream) {
-        val input    = new BufferedReader(new StringReader(code))
-        val output   = new JPrintWriter(new OutputStreamWriter(ostream), true)
-        val repl     = new SparkILoop(input, output)
+        val input = new BufferedReader(new StringReader(code))
+        val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
+        val repl = new SparkILoop(input, output)
 
         if (sets.classpath.isDefault)
           sets.classpath.value = sys.props("java.class.path")
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
deleted file mode 100644
index 1cb910f376060..0000000000000
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ /dev/null
@@ -1,1319 +0,0 @@
-/* NSC -- new Scala compiler
- * Copyright 2005-2013 LAMP/EPFL
- * @author  Martin Odersky
- */
-
-package scala
-package tools.nsc
-package interpreter
-
-import PartialFunction.cond
-import scala.language.implicitConversions
-import scala.beans.BeanProperty
-import scala.collection.mutable
-import scala.concurrent.{ Future, ExecutionContext }
-import scala.reflect.runtime.{ universe => ru }
-import scala.reflect.{ ClassTag, classTag }
-import scala.reflect.internal.util.{ BatchSourceFile, SourceFile }
-import scala.tools.util.PathResolver
-import scala.tools.nsc.io.AbstractFile
-import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings }
-import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps }
-import scala.tools.nsc.util.Exceptional.unwrap
-import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable}
-
-/** An interpreter for Scala code.
-  *
-  *  The main public entry points are compile(), interpret(), and bind().
-  *  The compile() method loads a complete Scala file.  The interpret() method
-  *  executes one line of Scala code at the request of the user.  The bind()
-  *  method binds an object to a variable that can then be used by later
-  *  interpreted code.
-  *
-  *  The overall approach is based on compiling the requested code and then
-  *  using a Java classloader and Java reflection to run the code
-  *  and access its results.
-  *
-  *  In more detail, a single compiler instance is used
-  *  to accumulate all successfully compiled or interpreted Scala code.  To
-  *  "interpret" a line of code, the compiler generates a fresh object that
-  *  includes the line of code and which has public member(s) to export
-  *  all variables defined by that code.  To extract the result of an
-  *  interpreted line to show the user, a second "result object" is created
-  *  which imports the variables exported by the above object and then
-  *  exports members called "$eval" and "$print". To accomodate user expressions
-  *  that read from variables or methods defined in previous statements, "import"
-  *  statements are used.
-  *
-  *  This interpreter shares the strengths and weaknesses of using the
-  *  full compiler-to-Java.  The main strength is that interpreted code
-  *  behaves exactly as does compiled code, including running at full speed.
-  *  The main weakness is that redefining classes and methods is not handled
-  *  properly, because rebinding at the Java level is technically difficult.
-  *
-  *  @author Moez A. Abdel-Gawad
-  *  @author Lex Spoon
-  */
-class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings,
-  protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports {
-  imain =>
-
-  setBindings(createBindings, ScriptContext.ENGINE_SCOPE)
-  object replOutput extends ReplOutput(settings.Yreploutdir) { }
-
-  @deprecated("Use replOutput.dir instead", "2.11.0")
-  def virtualDirectory = replOutput.dir
-  // Used in a test case.
-  def showDirectory() = replOutput.show(out)
-
-  private[nsc] var printResults               = true      // whether to print result lines
-  private[nsc] var totalSilence               = false     // whether to print anything
-  private var _initializeComplete             = false     // compiler is initialized
-  private var _isInitialized: Future[Boolean] = null      // set up initialization future
-  private var bindExceptions                  = true      // whether to bind the lastException variable
-  private var _executionWrapper               = ""        // code to be wrapped around all lines
-
-  /** We're going to go to some trouble to initialize the compiler asynchronously.
-    *  It's critical that nothing call into it until it's been initialized or we will
-    *  run into unrecoverable issues, but the perceived repl startup time goes
-    *  through the roof if we wait for it.  So we initialize it with a future and
-    *  use a lazy val to ensure that any attempt to use the compiler object waits
-    *  on the future.
-    */
-  private var _classLoader: util.AbstractFileClassLoader = null                              // active classloader
-  private val _compiler: ReplGlobal                 = newCompiler(settings, reporter)   // our private compiler
-
-  def compilerClasspath: Seq[java.net.URL] = (
-    if (isInitializeComplete) global.classPath.asURLs
-    else new PathResolver(settings).result.asURLs  // the compiler's classpath
-    )
-  def settings = initialSettings
-  // Run the code body with the given boolean settings flipped to true.
-  def withoutWarnings[T](body: => T): T = beQuietDuring {
-    val saved = settings.nowarn.value
-    if (!saved)
-      settings.nowarn.value = true
-
-    try body
-    finally if (!saved) settings.nowarn.value = false
-  }
-
-  /** construct an interpreter that reports to Console */
-  def this(settings: Settings, out: JPrintWriter) = this(null, settings, out)
-  def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true))
-  def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
-  def this(factory: ScriptEngineFactory) = this(factory, new Settings())
-  def this() = this(new Settings())
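The entry points named in the class comment above are easiest to see with a small driver. The following is a minimal sketch, not a test from this patch: it assumes the Spark 2.11 REPL module and scala-compiler are on the classpath, and that bind() keeps the standard IMain signature (name, type string, value).

import scala.tools.nsc.Settings
import org.apache.spark.repl.SparkIMain

object SparkIMainSketch {
  def main(args: Array[String]): Unit = {
    val settings = new Settings
    settings.usejavacp.value = true        // compile user lines against the JVM classpath
    val intp = new SparkIMain(settings)    // console-reporting constructor defined above
    intp.initializeSynchronous()           // block until the embedded compiler is ready
    intp.interpret("val x = 21")           // one line of user code; defines x in the REPL scope
    intp.bind("y", "Int", 21)              // bind a host-side value so later lines can read it
    intp.interpret("println(x + y)")       // prints 42
  }
}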
-
-  lazy val formatting: Formatting = new Formatting {
-    val prompt = Properties.shellPromptString
-  }
-  lazy val reporter: SparkReplReporter = new SparkReplReporter(this)
-
-  import formatting._
-  import reporter.{ printMessage, printUntruncatedMessage }
-
-  // This exists mostly because using the reporter too early leads to deadlock.
-  private def echo(msg: String) { Console println msg }
-  private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
-  private def _initialize() = {
-    try {
-      // if this crashes, REPL will hang its head in shame
-      val run = new _compiler.Run()
-      assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
-      run compileSources _initSources
-      _initializeComplete = true
-      true
-    }
-    catch AbstractOrMissingHandler()
-  }
-  private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
-  private val logScope = scala.sys.props contains "scala.repl.scope"
-  private def scopelog(msg: String) = if (logScope) Console.err.println(msg)
-
-  // argument is a thunk to execute after init is done
-  def initialize(postInitSignal: => Unit) {
-    synchronized {
-      if (_isInitialized == null) {
-        _isInitialized =
-          Future(try _initialize() finally postInitSignal)(ExecutionContext.global)
-      }
-    }
-  }
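A front end can signal readiness with any synchronization primitive; the CountDownLatch wiring below is an illustrative sketch of the future-based startup described above, not code from this repository.

import java.util.concurrent.CountDownLatch
import scala.tools.nsc.Settings
import org.apache.spark.repl.SparkIMain

object AsyncStartupSketch {
  def main(args: Array[String]): Unit = {
    val settings = new Settings
    settings.usejavacp.value = true
    val intp  = new SparkIMain(settings)
    val ready = new CountDownLatch(1)
    intp.initialize(ready.countDown())     // compiler init runs on ExecutionContext.global
    println("banner and other compiler-free work can happen here")
    ready.await()                          // anything that touches the compiler must wait
    intp.interpret("println(\"compiler is up\")")
  }
}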
-  def initializeSynchronous(): Unit = {
-    if (!isInitializeComplete) {
-      _initialize()
-      assert(global != null, global)
-    }
-  }
-  def isInitializeComplete = _initializeComplete
-
-  lazy val global: Global = {
-    if (!isInitializeComplete) _initialize()
-    _compiler
-  }
-
-  import global._
-  import definitions.{ ObjectClass, termMember, dropNullaryMethod}
-
-  lazy val runtimeMirror = ru.runtimeMirror(classLoader)
-
-  private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol }
-
-  def getClassIfDefined(path: String)  = (
-    noFatal(runtimeMirror staticClass path)
-      orElse noFatal(rootMirror staticClass path)
-    )
-  def getModuleIfDefined(path: String) = (
-    noFatal(runtimeMirror staticModule path)
-      orElse noFatal(rootMirror staticModule path)
-    )
-
-  implicit class ReplTypeOps(tp: Type) {
-    def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
-  }
-
-  // TODO: If we try to make naming a lazy val, we run into big time
-  // scalac unhappiness with what look like cycles.  It has not been easy to
-  // reduce, but name resolution clearly takes different paths.
-  object naming extends {
-    val global: imain.global.type = imain.global
-  } with Naming {
-    // make sure we don't overwrite their unwisely named res3 etc.
-    def freshUserTermName(): TermName = {
-      val name = newTermName(freshUserVarName())
-      if (replScope containsName name) freshUserTermName()
-      else name
-    }
-    def isInternalTermName(name: Name) = isInternalVarName("" + name)
-  }
-  import naming._
-
-  object deconstruct extends {
-    val global: imain.global.type = imain.global
-  } with StructuredTypeStrings
-
-  lazy val memberHandlers = new {
-    val intp: imain.type = imain
-  } with SparkMemberHandlers
-  import memberHandlers._
-
-  /** Temporarily be quiet */
-  def beQuietDuring[T](body: => T): T = {
-    val saved = printResults
-    printResults = false
-    try body
-    finally printResults = saved
-  }
-  def beSilentDuring[T](operation: => T): T = {
-    val saved = totalSilence
-    totalSilence = true
-    try operation
-    finally totalSilence = saved
-  }
-
-  def quietRun[T](code: String) = beQuietDuring(interpret(code))
-
-  /** takes AnyRef because it may be binding a Throwable or an Exceptional */
-  private def withLastExceptionLock[T](body: => T, alt: => T): T = {
-    assert(bindExceptions, "withLastExceptionLock called incorrectly.")
-    bindExceptions = false
-
-    try     beQuietDuring(body)
-    catch   logAndDiscard("withLastExceptionLock", alt)
-    finally bindExceptions = true
-  }
-
-  def executionWrapper = _executionWrapper
-  def setExecutionWrapper(code: String) = _executionWrapper = code
-  def clearExecutionWrapper() = _executionWrapper = ""
-
-  /** interpreter settings */
-  lazy val isettings = new SparkISettings(this)
-
-  /** Instantiate a compiler.  Overridable. */
-  protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = {
-    settings.outputDirs setSingleOutput replOutput.dir
-    settings.exposeEmptyPackage.value = true
-    new Global(settings, reporter) with ReplGlobal { override def toString: String = "<global>" }
-  }
-
-  /** Parent classloader.  Overridable. */
-  protected def parentClassLoader: ClassLoader =
-    settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() )
-
-  /* A single class loader is used for all commands interpreted by this Interpreter.
-     It would also be possible to create a new class loader for each command
-     to interpret.  The advantages of the current approach are:
-
-       - Expressions are only evaluated one time.  This is especially
-         significant for I/O, e.g. "val x = Console.readLine"
-
-     The main disadvantage is:
-
-       - Objects, classes, and methods cannot be rebound.  Instead, definitions
-         shadow the old ones, and old code objects refer to the old
-         definitions.
-  */
-  def resetClassLoader() = {
-    repldbg("Setting new classloader: was " + _classLoader)
-    _classLoader = null
-    ensureClassLoader()
-  }
-  final def ensureClassLoader() {
-    if (_classLoader == null)
-      _classLoader = makeClassLoader()
-  }
-  def classLoader: util.AbstractFileClassLoader = {
-    ensureClassLoader()
-    _classLoader
-  }
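The trade-off spelled out in the class-loading comment above can be seen directly. In the sketch below (same assumed setup as the earlier examples), redefining greet() shadows the old symbol, while a closure compiled against the original keeps calling it.

import scala.tools.nsc.Settings
import org.apache.spark.repl.SparkIMain

object ShadowingSketch {
  def main(args: Array[String]): Unit = {
    val settings = new Settings
    settings.usejavacp.value = true
    val intp = new SparkIMain(settings)
    intp.initializeSynchronous()
    intp.interpret("def greet() = \"hello\"")
    intp.interpret("val g = () => greet()")      // g is compiled against the first greet()
    intp.interpret("def greet() = \"goodbye\"")  // the new definition shadows, it does not rebind
    intp.interpret("println(greet())")           // prints "goodbye": new code sees the new definition
    intp.interpret("println(g())")               // prints "hello": old code objects keep the old one
  }
}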
-
-  def backticked(s: String): String = (
-    (s split '.').toList map {
-      case "_"                               => "_"
-      case s if nme.keywords(newTermName(s)) => s"`$s`"
-      case s                                 => s
-    } mkString "."
-    )
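For illustration, the expected behaviour of backticked (a sketch; the exact keyword set comes from nme.keywords):

// `intp` is assumed to be a SparkIMain set up as in the earlier sketches.
assert(intp.backticked("a.package.C") == "a.`package`.C")  // keyword segments get back-quoted
assert(intp.backticked("a._.C") == "a._.C")                // "_" is passed through as-is
assert(intp.backticked("a.b.C") == "a.b.C")                // ordinary segments are left unchanged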
-  def readRootPath(readPath: String) = getModuleIfDefined(readPath)
-
-  abstract class PhaseDependentOps {
-    def shift[T](op: => T): T
-
-    def path(name: => Name): String = shift(path(symbolOfName(name)))
-    def path(sym: Symbol): String = backticked(shift(sym.fullName))
-    def sig(sym: Symbol): String  = shift(sym.defString)
-  }
-  object typerOp extends PhaseDependentOps {
-    def shift[T](op: => T): T = exitingTyper(op)
-  }
-  object flatOp extends PhaseDependentOps {
-    def shift[T](op: => T): T = exitingFlatten(op)
-  }
-
-  def originalPath(name: String): String = originalPath(name: TermName)
-  def originalPath(name: Name): String   = typerOp path name
-  def originalPath(sym: Symbol): String  = typerOp path sym
-  def flatPath(sym: Symbol): String      = flatOp shift sym.javaClassName
-  def translatePath(path: String) = {
-    val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path)
-    sym.toOption map flatPath
-  }
-  def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath
-
-  private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) {
-    /** Overridden here to try translating a simple name to the generated
-      *  class name if the original attempt fails.  This method is used by
-      *  getResourceAsStream as well as findClass.
-      */
-    override protected def findAbstractFile(name: String): AbstractFile =
-      super.findAbstractFile(name) match {
-        case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull
-        case file => file
-      }
-  }
-  private def makeClassLoader(): util.AbstractFileClassLoader =
-    new TranslatingClassLoader(parentClassLoader match {
-      case null   => ScalaClassLoader fromURLs compilerClasspath
-      case p      => new ScalaClassLoader.URLClassLoader(compilerClasspath, p)
-    })
-
-  // Set the current Java "context" class loader to this interpreter's class loader
-  def setContextClassLoader() = classLoader.setAsContext()
-
-  def allDefinedNames: List[Name]  = exitingTyper(replScope.toList.map(_.name).sorted)
-  def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted
-
-  /** Most recent tree handled which wasn't wholly synthetic. */
-  private def mostRecentlyHandledTree: Option[Tree] = {
-    prevRequests.reverse foreach { req =>
-      req.handlers.reverse foreach {
-        case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
-        case _ => ()
-      }
-    }
-    None
-  }
-
-  private def updateReplScope(sym: Symbol, isDefined: Boolean) {
-    def log(what: String) {
-      val mark = if (sym.isType) "t " else "v "
-      val name = exitingTyper(sym.nameString)
-      val info = cleanTypeAfterTyper(sym)
-      val defn = sym defStringSeenAs info
-
-      scopelog(f"[$mark$what%6s] $name%-25s $defn%s")
-    }
-    if (ObjectClass isSubClass sym.owner) return
-    // unlink previous
-    replScope lookupAll sym.name foreach { sym =>
-      log("unlink")
-      replScope unlink sym
-    }
-    val what = if (isDefined) "define" else "import"
-    log(what)
-    replScope enter sym
-  }
-
-  def recordRequest(req: Request) {
-    if (req == null)
-      return
-
-    prevRequests += req
-
-    // warning about serially defining companions.  It'd be easy
-    // enough to just redefine them together but that may not always
-    // be what people want so I'm waiting until I can do it better.
-    exitingTyper {
-      req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym =>
-        val oldSym = replScope lookup newSym.name.companionName
-        if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) {
-          replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")
-          replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
-        }
-      }
-    }
-    exitingTyper {
-      req.imports foreach (sym => updateReplScope(sym, isDefined = false))
-      req.defines foreach (sym => updateReplScope(sym, isDefined = true))
-    }
-  }
-
-  private[nsc] def replwarn(msg: => String) {
-    if (!settings.nowarnings)
-      printMessage(msg)
-  }
-
-  def compileSourcesKeepingRun(sources: SourceFile*) = {
-    val run = new Run()
-    assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
-    reporter.reset()
-    run compileSources sources.toList
-    (!reporter.hasErrors, run)
-  }
-
-  /** Compile an nsc SourceFile.  Returns true if there are
-    *  no compilation errors, or false otherwise.
-    */
-  def compileSources(sources: SourceFile*): Boolean =
-    compileSourcesKeepingRun(sources: _*)._1
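A sketch of loading a complete source file through compileSources() and then using the result from interpreted code. The Doubler object is invented for illustration, and the setup mirrors the earlier examples.

import scala.reflect.internal.util.BatchSourceFile
import scala.tools.nsc.Settings
import org.apache.spark.repl.SparkIMain

object CompileSourcesSketch {
  def main(args: Array[String]): Unit = {
    val settings = new Settings
    settings.usejavacp.value = true
    val intp = new SparkIMain(settings)
    intp.initializeSynchronous()
    val ok = intp.compileSources(
      new BatchSourceFile("Doubler.scala", "object Doubler { def twice(i: Int) = 2 * i }"))
    if (ok) intp.interpret("println(Doubler.twice(21))")  // compiled classes are visible to later lines
  }
}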
-
-  /** Compile a string.  Returns true if there are no
-    *  compilation errors, or false otherwise.
-    */
-  def compileString(code: String): Boolean =
-    compileSources(new BatchSourceFile("