diff --git a/.gitignore b/.gitignore index 3624d12269612..debad77ec2ad3 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,7 @@ scalastyle-output.xml R-unit-tests.log R/unit-tests.out python/lib/pyspark.zip +lint-r-report.log # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index c0f81b57fe09d..236c2db05367c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -28,6 +28,7 @@ spark-env.sh spark-env.cmd spark-env.sh.template log4j-defaults.properties +log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js @@ -80,5 +81,15 @@ local-1425081759269/* local-1426533911241/* local-1426633911242/* local-1430917381534/* +local-1430917381535_1 +local-1430917381535_2 DESCRIPTION NAMESPACE +test_support/* +.*Rd +help/* +html/* +INDEX +.lintr +gen-java.* +.*avpr diff --git a/LICENSE b/LICENSE index d6b9ccf07d999..f9e412cade345 100644 --- a/LICENSE +++ b/LICENSE @@ -853,6 +853,52 @@ and Vis.js may be distributed under either license. +======================================================================== +For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +======================================================================== +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +======================================================================== +For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): +======================================================================== +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + ======================================================================== BSD-style licenses ======================================================================== @@ -861,7 +907,7 @@ The following components are provided under a BSD-style license. See project lin (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) - (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/) + (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) @@ -902,5 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) + (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/R/README.md b/R/README.md index a6970e39b55f3..005f56da1670c 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ``` build/mvn -DskipTests -Psparkr package ``` @@ -52,7 +52,7 @@ The SparkR documentation (Rd files and HTML files) are not a part of the source SparkR comes with several sample programs in the `examples/src/main/r` directory. To run one of them, use `./bin/sparkR `. For example: - ./bin/sparkR examples/src/main/r/pi.R local[2] + ./bin/sparkR examples/src/main/r/dataframe.R You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): @@ -63,5 +63,5 @@ You can also run the unit-tests for SparkR by running (you need to install the [ The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. 
For example on CDH you can run ``` export YARN_CONF_DIR=/etc/hadoop/conf -./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +./bin/spark-submit --master yarn examples/src/main/r/dataframe.R ``` diff --git a/R/create-docs.sh b/R/create-docs.sh index 4194172a2e115..6a4687b06ecb9 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -23,14 +23,14 @@ # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html +set -o pipefail +set -e + # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" pushd $FWDIR -# Generate Rd file -Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' - -# Install the package +# Install the package (this will also generate the Rd files) ./install-dev.sh # Now create HTML files diff --git a/R/install-dev.bat b/R/install-dev.bat index 008a5c668bc45..f32670b67de96 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0.. MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ + +rem Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd %SPARK_HOME%\R\lib +%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR +popd diff --git a/R/install-dev.sh b/R/install-dev.sh index 55ed6f4be1a4a..4972bb9217072 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -26,11 +26,24 @@ # NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory # to load the SparkR package on the worker nodes. +set -o pipefail +set -e FWDIR="$(cd `dirname $0`; pwd)" LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -# Install R +pushd $FWDIR > /dev/null + +# Generate Rd files if devtools is installed +Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' + +# Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ + +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +cd $LIB_DIR +jar cfM "$LIB_DIR/sparkr.zip" SparkR + +popd > /dev/null diff --git a/R/log4j.properties b/R/log4j.properties index 701adb2a3da1d..cce8d9152d32d 100644 --- a/R/log4j.properties +++ b/R/log4j.properties @@ -19,7 +19,7 @@ log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true -log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.file=R/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n diff --git a/R/pkg/.lintr b/R/pkg/.lintr new file mode 100644 index 0000000000000..038236fc149e6 --- /dev/null +++ b/R/pkg/.lintr @@ -0,0 +1,2 @@ +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index efc85bbc4b316..d028821534b1a 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -32,4 +32,3 @@ Collate: 'serialize.R' 'sparkR.R' 'utils.R' - 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 819e9a24e5c0e..7f857222452d4 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,12 +1,20 @@ # Imports from base R importFrom(methods, setGeneric, setMethod, setOldClass) -useDynLib(SparkR, stringHashCode) + +# Disable native libraries till we figure out how to package it +# See 
SPARKR-7839 +#useDynLib(SparkR, stringHashCode) # S3 methods exported export("sparkR.init") export("sparkR.stop") export("print.jobj") +# Job group lifecycle management methods +export("setJobGroup", + "clearJobGroup", + "cancelJobGroup") + exportClasses("DataFrame") exportMethods("arrange", @@ -16,9 +24,11 @@ exportMethods("arrange", "count", "describe", "distinct", + "dropna", "dtypes", "except", "explain", + "fillna", "filter", "first", "group_by", @@ -37,7 +47,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", - "sampleDF", + "sample", "sample_frac", "saveAsParquetFile", "saveAsTable", @@ -53,38 +63,62 @@ exportMethods("arrange", "unpersist", "where", "withColumn", - "withColumnRenamed") + "withColumnRenamed", + "write.df") exportClasses("Column") exportMethods("abs", + "acos", "alias", "approxCountDistinct", "asc", + "asin", + "atan", + "atan2", "avg", "cast", + "cbrt", + "ceiling", "contains", + "cos", + "cosh", "countDistinct", "desc", "endsWith", + "exp", + "expm1", + "floor", "getField", "getItem", + "hypot", "isNotNull", "isNull", "last", "like", + "log", + "log10", + "log1p", "lower", "max", "mean", "min", "n", "n_distinct", + "rint", "rlike", + "sign", + "sin", + "sinh", "sqrt", "startsWith", "substr", "sum", "sumDistinct", + "tan", + "tanh", + "toDegrees", + "toRadians", "upper") exportClasses("GroupedData") @@ -101,6 +135,7 @@ export("cacheTable", "jsonFile", "loadDF", "parquetFile", + "read.df", "sql", "table", "tableNames", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 2705817531019..60702824acb46 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -38,7 +38,7 @@ setClass("DataFrame", setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { .Object@env <- new.env() .Object@env$isCached <- isCached - + .Object@sdf <- sdf .Object }) @@ -55,19 +55,19 @@ dataFrame <- function(sdf, isCached = FALSE) { ############################ DataFrame Methods ############################################## #' Print Schema of a DataFrame -#' +#' #' Prints out the schema in tree format -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname printSchema #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -78,19 +78,19 @@ setMethod("printSchema", }) #' Get schema object -#' +#' #' Returns the schema of this DataFrame as a structType object. -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname schema #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -100,9 +100,9 @@ setMethod("schema", }) #' Explain -#' +#' #' Print the logical and physical Catalyst plans to the console for debugging. -#' +#' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. 
#' @rdname explain @@ -110,9 +110,9 @@ setMethod("schema", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -139,9 +139,9 @@ setMethod("explain", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -162,15 +162,15 @@ setMethod("isLocal", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' showDF(df) #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) @@ -185,9 +185,9 @@ setMethod("showDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -200,19 +200,19 @@ setMethod("show", "DataFrame", }) #' DataTypes -#' +#' #' Return all column names and their data types as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname dtypes #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -224,19 +224,19 @@ setMethod("dtypes", }) #' Column names -#' +#' #' Return all column names as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname columns #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' columns(df) #'} setMethod("columns", @@ -256,22 +256,22 @@ setMethod("names", }) #' Register Temporary Table -#' +#' #' Registers a DataFrame as a Temporary Table in the SQLContext -#' +#' #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table -#' +#' #' @rdname registerTempTable #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "json_df") -#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} setMethod("registerTempTable", signature(x = "DataFrame", tableName = "character"), @@ -293,9 +293,9 @@ setMethod("registerTempTable", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- loadDF(sqlCtx, path, "parquet") -#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") +#' df2 <- read.df(sqlContext, path2, "parquet") #' registerTempTable(df, "table1") #' insertInto(df2, "table1", 
overwrite = TRUE) #'} @@ -306,19 +306,19 @@ setMethod("insertInto", }) #' Cache -#' +#' #' Persist with the default storage level (MEMORY_ONLY). -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname cache-methods #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -341,9 +341,9 @@ setMethod("cache", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -366,9 +366,9 @@ setMethod("persist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -391,16 +391,16 @@ setMethod("unpersist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", signature(x = "DataFrame", numPartitions = "numeric"), function(x, numPartitions) { sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) - dataFrame(sdf) + dataFrame(sdf) }) # toJSON @@ -415,9 +415,9 @@ setMethod("repartition", # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # newRDD <- toJSON(df) #} setMethod("toJSON", @@ -440,9 +440,9 @@ setMethod("toJSON", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsParquetFile(df, "/tmp/sparkr-tmp/") #'} setMethod("saveAsParquetFile", @@ -461,9 +461,9 @@ setMethod("saveAsParquetFile", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -473,26 +473,26 @@ setMethod("distinct", dataFrame(sdf) }) -#' SampleDF +#' Sample #' #' Return a sampled subset of this DataFrame using a random seed. 
#' #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction -#' @rdname sampleDF +#' @rdname sample #' @aliases sample_frac #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) -#' collect(sampleDF(df, FALSE, 0.5)) -#' collect(sampleDF(df, TRUE, 0.5)) +#' df <- jsonFile(sqlContext, path) +#' collect(sample(df, FALSE, 0.5)) +#' collect(sample(df, TRUE, 0.5)) #'} -setMethod("sampleDF", +setMethod("sample", # TODO : Figure out how to send integer as java.lang.Long to JVM so # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", @@ -503,29 +503,29 @@ setMethod("sampleDF", dataFrame(sdf) }) -#' @rdname sampleDF -#' @aliases sampleDF +#' @rdname sample +#' @aliases sample setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), function(x, withReplacement, fraction) { - sampleDF(x, withReplacement, fraction) + sample(x, withReplacement, fraction) }) #' Count -#' +#' #' Returns the number of rows in a DataFrame -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname count #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' count(df) #' } setMethod("count", @@ -545,9 +545,9 @@ setMethod("count", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } @@ -568,21 +568,21 @@ setMethod("collect", }) #' Limit -#' +#' #' Limit the resulting DataFrame to the number of rows specified. -#' +#' #' @param x A SparkSQL DataFrame #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. -#' +#' #' @rdname limit #' @export #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -593,15 +593,15 @@ setMethod("limit", }) #' Take the first NUM rows of a DataFrame and return a the results as a data.frame -#' +#' #' @rdname take #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -613,8 +613,8 @@ setMethod("take", #' Head #' -#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, -#' then head() returns the first 6 rows in keeping with the current data.frame +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame #' convention in R. 
#' #' @param x A SparkSQL DataFrame @@ -626,9 +626,9 @@ setMethod("take", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' head(df) #' } setMethod("head", @@ -647,9 +647,9 @@ setMethod("head", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' first(df) #' } setMethod("first", @@ -659,19 +659,19 @@ setMethod("first", }) # toRDD() -# +# # Converts a Spark DataFrame to an RDD while preserving column names. -# +# # @param x A Spark DataFrame -# +# # @rdname DataFrame # @export # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # rdd <- toRDD(df) # } setMethod("toRDD", @@ -938,9 +938,9 @@ setMethod("select", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -964,9 +964,9 @@ setMethod("selectExpr", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -988,9 +988,9 @@ setMethod("withColumn", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' } @@ -1024,9 +1024,9 @@ setMethod("mutate", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1055,9 +1055,9 @@ setMethod("withColumnRenamed", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1095,9 +1095,9 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, "col1") #' arrange(df, asc(df$col1), desc(abs(df$col2))) @@ -1137,9 +1137,9 @@ setMethod("orderBy", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } 
@@ -1167,7 +1167,7 @@ setMethod("where", #' #' @param x A Spark DataFrame #' @param y A Spark DataFrame -#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a #' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join #' @param joinType The type of join to perform. The following join types are available: #' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". @@ -1177,9 +1177,9 @@ setMethod("where", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1219,9 +1219,9 @@ setMethod("join", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1244,9 +1244,9 @@ setMethod("unionAll", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1269,9 +1269,9 @@ setMethod("intersect", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1303,23 +1303,22 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' -#' @rdname saveAsTable +#' @rdname write.df #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) -#' saveAsTable(df, "myfile") +#' df <- jsonFile(sqlContext, path) +#' write.df(df, "myfile", "parquet", "overwrite") #' } -setMethod("saveDF", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ +setMethod("write.df", + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1334,6 +1333,14 @@ setMethod("saveDF", callJMethod(df@sdf, "save", source, jmode, options) }) +#' @rdname write.df +#' @aliases saveDF +#' @export +setMethod("saveDF", + signature(df = "DataFrame", path = 
'character'), + function(df, path, source = NULL, mode = "append", ...){ + write.df(df, path, source, mode, ...) + }) #' saveAsTable #' @@ -1362,9 +1369,9 @@ setMethod("saveDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", @@ -1372,8 +1379,8 @@ setMethod("saveAsTable", mode = 'character'), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1394,14 +1401,14 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame -#' @rdname describe +#' @rdname describe #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -1422,3 +1429,128 @@ setMethod("describe", sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) dataFrame(sdf) }) + +#' dropna +#' +#' Returns a new DataFrame omitting rows with null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param how "any" or "all". +#' if "any", drop a row if it contains any nulls. +#' if "all", drop a row only if all its values are null. +#' if minNonNulls is specified, how is ignored. +#' @param minNonNulls If specified, drop rows that have less than +#' minNonNulls non-null values. +#' This overwrites the how parameter. +#' @param cols Optional list of column names to consider. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dropna(df) +#' } +setMethod("dropna", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + how <- match.arg(how) + if (is.null(cols)) { + cols <- columns(x) + } + if (is.null(minNonNulls)) { + minNonNulls <- if (how == "any") { length(cols) } else { 1 } + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- callJMethod(naFunctions, "drop", + as.integer(minNonNulls), listToSeq(as.list(cols))) + dataFrame(sdf) + }) + +#' @aliases dropna +#' @export +setMethod("na.omit", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(x, how, minNonNulls, cols) + }) + +#' fillna +#' +#' Replace null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param value Value to replace null values with. +#' Should be an integer, numeric, character or named list. +#' If the value is a named list, then cols is ignored and +#' value must be a mapping from column name (character) to +#' replacement value. The replacement value must be an +#' integer, numeric or character. +#' @param cols optional list of column names to consider. +#' Columns specified in cols that do not have matching data +#' type are ignored. 
For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' fillna(df, 1) +#' fillna(df, list("age" = 20, "name" = "unknown")) +#' } +setMethod("fillna", + signature(x = "DataFrame"), + function(x, value, cols = NULL) { + if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { + stop("value should be an integer, numeric, charactor or named list.") + } + + if (class(value) == "list") { + # Check column names in the named list + colNames <- names(value) + if (length(colNames) == 0 || !all(colNames != "")) { + stop("value should be an a named list with each name being a column name.") + } + + # Convert to the named list to an environment to be passed to JVM + valueMap <- new.env() + for (col in colNames) { + # Check each item in the named list is of valid type + v <- value[[col]] + if (!(class(v) %in% c("integer", "numeric", "character"))) { + stop("Each item in value should be an integer, numeric or charactor.") + } + valueMap[[col]] <- v + } + + # When value is a named list, caller is expected not to pass in cols + if (!is.null(cols)) { + warning("When value is a named list, cols is ignored!") + cols <- NULL + } + + value <- valueMap + } else if (is.integer(value)) { + # Cast an integer to a numeric + value <- as.numeric(value) + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- if (length(cols) == 0) { + callJMethod(naFunctions, "fill", value) + } else { + callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + } + dataFrame(sdf) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 9138629cac9c0..d2d096709245d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -48,7 +48,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, # byte: The RDD stores data serialized in R. # string: The RDD stores data as strings. # row: The RDD stores the serialized rows of a DataFrame. - + # We use an environment to store mutable states inside an RDD object. # Note that R's call-by-value semantics makes modifying slots inside an # object (passed as an argument into a function, such as cache()) difficult: @@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), serializedFuncArr, rdd@env$prev_serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } else { @@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), rdd@env$prev_serializedMode, serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } @@ -239,7 +237,7 @@ setMethod("cache", # @aliases persist,RDD-method setMethod("persist", signature(x = "RDD", newLevel = "character"), - function(x, newLevel) { + function(x, newLevel = "MEMORY_ONLY") { callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) x@env$isCached <- TRUE x @@ -363,7 +361,7 @@ setMethod("collectPartition", # @description # \code{collectAsMap} returns a named list as a map that contains all of the elements -# in a key-value pair RDD. +# in a key-value pair RDD. 
# @examples #\dontrun{ # sc <- sparkR.init() @@ -666,7 +664,7 @@ setMethod("minimum", # rdd <- parallelize(sc, 1:10) # sumRDD(rdd) # 55 #} -# @rdname sumRDD +# @rdname sumRDD # @aliases sumRDD,RDD setMethod("sumRDD", signature(x = "RDD"), @@ -927,7 +925,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", MAXINT))))) # TODO(zongheng): investigate if this call is an in-place shuffle? - sample(samples)[1:total] + base::sample(samples)[1:total] }) # Creates tuples of the elements in this RDD by applying a function. @@ -996,7 +994,7 @@ setMethod("coalesce", if (shuffle || numPartitions > SparkR:::numPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed - start <- as.integer(sample(numPartitions, 1) - 1) + start <- as.integer(base::sample(numPartitions, 1) - 1) lapply(seq_along(part), function(i) { pos <- (start + i) %% numPartitions @@ -1090,11 +1088,11 @@ setMethod("sortBy", # Return: # A list of the first N elements from the RDD in the specified order. # -takeOrderedElem <- function(x, num, ascending = TRUE) { +takeOrderedElem <- function(x, num, ascending = TRUE) { if (num <= 0L) { return(list()) } - + partitionFunc <- function(part) { if (num < length(part)) { # R limitation: order works only on primitive types! @@ -1152,7 +1150,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { # @aliases takeOrdered,RDD,RDD-method setMethod("takeOrdered", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num) }) @@ -1173,7 +1171,7 @@ setMethod("takeOrdered", # @aliases top,RDD,RDD-method setMethod("top", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num, FALSE) }) @@ -1181,7 +1179,7 @@ setMethod("top", # # Aggregate the elements of each partition, and then the results for all the # partitions, using a given associative function and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param op An associative function for the folding operation. @@ -1207,7 +1205,7 @@ setMethod("fold", # # Aggregate the elements of each partition, and then the results for all the # partitions, using given combine functions and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param seqOp A function to aggregate the RDD elements. It may return a different @@ -1230,11 +1228,11 @@ setMethod("fold", # @aliases aggregateRDD,RDD,RDD-method setMethod("aggregateRDD", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), - function(x, zeroValue, seqOp, combOp) { + function(x, zeroValue, seqOp, combOp) { partitionFunc <- function(part) { Reduce(seqOp, part, zeroValue) } - + partitionList <- collect(lapplyPartition(x, partitionFunc), flatten = FALSE) Reduce(combOp, partitionList, zeroValue) @@ -1330,7 +1328,7 @@ setMethod("setName", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithUniqueId(rdd)) +# collect(zipWithUniqueId(rdd)) # # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #} # @rdname zipWithUniqueId @@ -1426,7 +1424,7 @@ setMethod("glom", partitionFunc <- function(part) { list(part) } - + lapplyPartition(x, partitionFunc) }) @@ -1498,16 +1496,16 @@ setMethod("zipRDD", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. 
rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, TRUE) }) # Cartesian product of this RDD and another one. # -# Return the Cartesian product of this RDD and another one, -# that is, the RDD of all pairs of elements (a, b) where a +# Return the Cartesian product of this RDD and another one, +# that is, the RDD of all pairs of elements (a, b) where a # is in this and b is in other. -# +# # @param x An RDD. # @param other An RDD. # @return A new RDD which is the Cartesian product of these two RDDs. @@ -1515,7 +1513,7 @@ setMethod("zipRDD", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, 1:2) -# sortByKey(cartesian(rdd, rdd)) +# sortByKey(cartesian(rdd, rdd)) # # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) #} # @rdname cartesian @@ -1528,7 +1526,7 @@ setMethod("cartesian", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, FALSE) }) @@ -1598,11 +1596,11 @@ setMethod("intersection", # Zips an RDD's partitions with one (or more) RDD(s). # Same as zipPartitions in Spark. -# +# # @param ... RDDs to be zipped. # @param func A function to transform zipped partitions. -# @return A new RDD by applying a function to the zipped partitions. -# Assumes that all the RDDs have the *same number of partitions*, but +# @return A new RDD by applying a function to the zipped partitions. +# Assumes that all the RDDs have the *same number of partitions*, but # does *not* require them to have the same number of elements in each partition. # @examples #\dontrun{ @@ -1610,7 +1608,7 @@ setMethod("intersection", # rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 # rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 # rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -# collect(zipPartitions(rdd1, rdd2, rdd3, +# collect(zipPartitions(rdd1, rdd2, rdd3, # func = function(x, y, z) { list(list(x, y, z))} )) # # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #} @@ -1627,7 +1625,7 @@ setMethod("zipPartitions", if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } - + rrdds <- lapply(rrdds, function(rdd) { mapPartitionsWithIndex(rdd, function(partIndex, part) { print(length(part)) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index cae06e6af2bff..30978bb50d339 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -69,7 +69,7 @@ infer_type <- function(x) { #' #' Converts an RDD to a DataFrame by infer the types. 
#' -#' @param sqlCtx A SQLContext +#' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return an DataFrame @@ -77,16 +77,18 @@ infer_type <- function(x) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -#' df <- createDataFrame(sqlCtx, rdd) +#' df <- createDataFrame(sqlContext, rdd) #' } # TODO(davies): support sampling and infer type from NA -createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { +createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD - schema <- names(data) + if (is.null(schema)) { + schema <- names(data) + } n <- nrow(data) m <- ncol(data) # get rid of factor type @@ -102,7 +104,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { }) } if (is.list(data)) { - sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) rdd <- parallelize(sc, data) } else if (inherits(data, "RDD")) { rdd <- data @@ -146,7 +148,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schema$jobj, sqlCtx) + srdd, schema$jobj, sqlContext) dataFrame(sdf) } @@ -161,7 +163,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) # df <- toDF(rdd) # } @@ -170,39 +172,39 @@ setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), function(x, ...) { - sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { get(".sparkRHivesc", envir = .sparkREnv) } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { get(".sparkRSQLsc", envir = .sparkREnv) } else { stop("no SQL context available") } - createDataFrame(sqlCtx, x, ...) + createDataFrame(sqlContext, x, ...) }) #' Create a DataFrame from a JSON file. #' -#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' Loads a JSON file (one object per line), returning the result as a DataFrame #' It goes through the entire dataset once to determine the schema. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. 
#' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' } -jsonFile <- function(sqlCtx, path) { +jsonFile <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path path <- normalizePath(path) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlCtx, "jsonFile", path) + sdf <- callJMethod(sqlContext, "jsonFile", path) dataFrame(sdf) } @@ -211,7 +213,7 @@ jsonFile <- function(sqlCtx, path) { # # Loads an RDD storing one JSON object per string as a DataFrame. # -# @param sqlCtx SQLContext to use +# @param sqlContext SQLContext to use # @param rdd An RDD of JSON string # @param schema A StructType object to use as schema # @param samplingRatio The ratio of simpling used to infer the schema @@ -220,16 +222,16 @@ jsonFile <- function(sqlCtx, path) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- texFile(sc, "path/to/json") -# df <- jsonRDD(sqlCtx, rdd) +# df <- jsonRDD(sqlContext, rdd) # } # TODO: support schema -jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { +jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { rdd <- serializeToString(rdd) if (is.null(schema)) { - sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) dataFrame(sdf) } else { stop("not implemented") @@ -238,68 +240,67 @@ jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { #' Create a DataFrame from a Parquet file. -#' +#' #' Loads a Parquet file, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param ... Path(s) of parquet file(s) to read. #' @return DataFrame #' @export # TODO: Implement saveasParquetFile and write examples for both -parquetFile <- function(sqlCtx, ...) { +parquetFile <- function(sqlContext, ...) { # Allow the user to have a more flexible definiton of the text file path paths <- lapply(list(...), normalizePath) - sdf <- callJMethod(sqlCtx, "parquetFile", paths) + sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } #' SQL Query -#' +#' #' Executes a SQL query using Spark, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param sqlQuery A character vector containing the SQL query #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' new_df <- sql(sqlContext, "SELECT * FROM table") #' } -sql <- function(sqlCtx, sqlQuery) { - sdf <- callJMethod(sqlCtx, "sql", sqlQuery) - dataFrame(sdf) +sql <- function(sqlContext, sqlQuery) { + sdf <- callJMethod(sqlContext, "sql", sqlQuery) + dataFrame(sdf) } - #' Create a DataFrame from a SparkSQL Table -#' +#' #' Returns the specified Table as a DataFrame. The Table must have already been registered #' in the SQLContext. 
#' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a DataFrame. #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- table(sqlCtx, "table") +#' new_df <- table(sqlContext, "table") #' } -table <- function(sqlCtx, tableName) { - sdf <- callJMethod(sqlCtx, "table", tableName) - dataFrame(sdf) +table <- function(sqlContext, tableName) { + sdf <- callJMethod(sqlContext, "table", tableName) + dataFrame(sdf) } @@ -307,22 +308,22 @@ table <- function(sqlCtx, tableName) { #' #' Returns a DataFrame containing names of tables in the given database. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tables(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tables(sqlContext, "hive") #' } -tables <- function(sqlCtx, databaseName = NULL) { +tables <- function(sqlContext, databaseName = NULL) { jdf <- if (is.null(databaseName)) { - callJMethod(sqlCtx, "tables") + callJMethod(sqlContext, "tables") } else { - callJMethod(sqlCtx, "tables", databaseName) + callJMethod(sqlContext, "tables", databaseName) } dataFrame(jdf) } @@ -332,82 +333,82 @@ tables <- function(sqlCtx, databaseName = NULL) { #' #' Returns the names of tables in the given database as an array. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a list of table names #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tableNames(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tableNames(sqlContext, "hive") #' } -tableNames <- function(sqlCtx, databaseName = NULL) { +tableNames <- function(sqlContext, databaseName = NULL) { if (is.null(databaseName)) { - callJMethod(sqlCtx, "tableNames") + callJMethod(sqlContext, "tableNames") } else { - callJMethod(sqlCtx, "tableNames", databaseName) + callJMethod(sqlContext, "tableNames", databaseName) } } #' Cache Table -#' +#' #' Caches the specified table in-memory. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being cached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' cacheTable(sqlCtx, "table") +#' cacheTable(sqlContext, "table") #' } -cacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "cacheTable", tableName) +cacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "cacheTable", tableName) } #' Uncache Table -#' +#' #' Removes the specified table from the in-memory cache. 
#' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being uncached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' uncacheTable(sqlCtx, "table") +#' uncacheTable(sqlContext, "table") #' } -uncacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "uncacheTable", tableName) +uncacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "uncacheTable", tableName) } #' Clear Cache #' #' Removes all cached tables from the in-memory cache. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @examples #' \dontrun{ -#' clearCache(sqlCtx) +#' clearCache(sqlContext) #' } -clearCache <- function(sqlCtx) { - callJMethod(sqlCtx, "clearCache") +clearCache <- function(sqlContext) { + callJMethod(sqlContext, "clearCache") } #' Drop Temporary Table @@ -415,22 +416,22 @@ clearCache <- function(sqlCtx) { #' Drops the temporary table with the given table name in the catalog. #' If the table has been cached/persisted before, it's also unpersisted. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the SparkSQL table to be dropped. #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- loadDF(sqlCtx, path, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") #' registerTempTable(df, "table") -#' dropTempTable(sqlCtx, "table") +#' dropTempTable(sqlContext, "table") #' } -dropTempTable <- function(sqlCtx, tableName) { +dropTempTable <- function(sqlContext, tableName) { if (class(tableName) != "character") { stop("tableName must be a string.") } - callJMethod(sqlCtx, "dropTempTable", tableName) + callJMethod(sqlContext, "dropTempTable", tableName) } #' Load an DataFrame @@ -441,7 +442,7 @@ dropTempTable <- function(sqlCtx, tableName) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path The path of files to load #' @param source the name of external data source #' @return DataFrame @@ -449,19 +450,37 @@ dropTempTable <- function(sqlCtx, tableName) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) 
if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "load", source, options) + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + schema$jobj, options) + } else { + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + } dataFrame(sdf) } +#' @aliases loadDF +#' @export + +loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { + read.df(sqlContext, path, source, schema, ...) +} + #' Create an external table #' #' Creates an external table based on the dataset in a data source, @@ -471,7 +490,7 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName A name of the table #' @param path The path of files to load #' @param source the name of external data source @@ -480,15 +499,15 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") #' } -createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { +createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) } diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 23dc38780716e..2403925b267c8 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -27,9 +27,9 @@ # @description Broadcast variables can be created using the broadcast # function from a \code{SparkContext}. # @rdname broadcast-class -# @seealso broadcast +# @seealso broadcast # -# @param id Id of the backing Spark broadcast variable +# @param id Id of the backing Spark broadcast variable # @export setClass("Broadcast", slots = list(id = "character")) @@ -68,7 +68,7 @@ setMethod("value", # variable on workers. Not intended for use outside the package. 
# # @rdname broadcast-internal -# @seealso broadcast, value +# @seealso broadcast, value # @param bcastId The id of broadcast variable to set # @param value The value to be set diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 1281c41213e32..78c7a3037ffac 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) { con } -launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { +determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { sparkSubmitBinName = "spark-submit" } else { sparkSubmitBinName = "spark-submit.cmd" } + sparkSubmitBinName +} + +generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + if (jars != "") { + jars <- paste("--jars", jars) + } + + if (packages != "") { + packages <- paste("--packages", packages) + } + combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") + combinedArgs +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) } else { sparkSubmitBin <- sparkSubmitBinName } - - if (jars != "") { - jars <- paste("--jars", jars) - } - - combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9a68445ab451a..8e4b0f5bf1c4d 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -55,12 +55,17 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or" #, "!" = "unary_$bang" + "&" = "and", "|" = "or", #, "!" 
= "unary_$bang" + "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct") + "first", "last", "lower", "upper", "sumDistinct", + "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", + "expm1", "floor", "log", "log10", "log1p", "rint", "sign", + "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") +binary_mathfunctions<- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -76,7 +81,11 @@ createOperator <- function(op) { if (class(e2) == "Column") { e2 <- e2@jc } - callJMethod(e1@jc, operators[[op]], e2) + if (op == "^") { + jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2) + } else { + callJMethod(e1@jc, operators[[op]], e2) + } } column(jc) }) @@ -106,11 +115,29 @@ createStaticFunction <- function(name) { setMethod(name, signature(x = "Column"), function(x) { + if (name == "ceiling") { + name <- "ceil" + } + if (name == "sign") { + name <- "signum" + } jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) column(jc) }) } +createBinaryMathfunctions <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -124,6 +151,9 @@ createMethods <- function() { for (x in functions) { createStaticFunction(x) } + for (name in binary_mathfunctions) { + createBinaryMathfunctions(name) + } } createMethods() @@ -180,6 +210,22 @@ setMethod("cast", } }) +#' Match a column with given values. +#' +#' @rdname column +#' @return a matched values as a result of comparing with given values. +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) +#' } +setMethod("%in%", + signature(x = "Column"), + function(x, table) { + table <- listToSeq(as.list(table)) + jc <- callJMethod(x@jc, "in", table) + return(column(jc)) + }) + #' Approx Count Distinct #' #' @rdname column diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 257b435607ce8..d961bbc383688 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -18,7 +18,7 @@ # Utility functions to deserialize objects from Java. 
# Type mapping from Java to R -# +# # void -> NULL # Int -> integer # String -> character diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 557128a419f19..fad9d71158c51 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -20,7 +20,8 @@ # @rdname aggregateRDD # @seealso reduce # @export -setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) # @rdname cache-methods # @export @@ -130,7 +131,7 @@ setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) -# @rdname sumRDD +# @rdname sumRDD # @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) @@ -219,7 +220,7 @@ setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD # @export -setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex @@ -364,7 +365,7 @@ setGeneric("subtract", # @rdname subtractByKey # @export -setGeneric("subtractByKey", +setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") }) @@ -396,6 +397,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname nafunctions +#' @export +setGeneric("dropna", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") + }) + +#' @rdname nafunctions +#' @export +setGeneric("na.omit", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") + }) + #' @rdname schema #' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) @@ -408,6 +423,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @export setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname nafunctions +#' @export +setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -456,19 +475,19 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) -#' @rdname sampleDF +#' @rdname sample #' @export -setGeneric("sample_frac", +setGeneric("sample", function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + standardGeneric("sample") + }) -#' @rdname sampleDF +#' @rdname sample #' @export -setGeneric("sampleDF", +setGeneric("sample_frac", function(x, withReplacement, fraction, seed) { - standardGeneric("sampleDF") - }) + standardGeneric("sample_frac") + }) #' @rdname saveAsParquetFile #' @export @@ -480,9 +499,13 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) -#' @rdname saveAsTable +#' @rdname write.df +#' @export +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) + +#' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) +setGeneric("saveDF", function(df, path, ...) 
{ standardGeneric("saveDF") }) #' @rdname schema #' @export @@ -548,6 +571,10 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) +#' @rdname column +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) @@ -571,6 +598,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -599,6 +630,10 @@ setGeneric("n", function(x) { standardGeneric("n") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) +#' @rdname column +#' @export +setGeneric("rint", function(x, ...) { standardGeneric("rint") }) + #' @rdname column #' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) @@ -613,5 +648,12 @@ setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname column #' @export -setGeneric("upper", function(x) { standardGeneric("upper") }) +setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) + +#' @rdname column +#' @export +setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) +#' @rdname column +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index b758481997574..8f1c68f7c4d28 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -136,4 +136,3 @@ createMethods <- function() { } createMethods() - diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index a8a25230b636d..0838a7bb35e0d 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -16,7 +16,7 @@ # # References to objects that exist on the JVM backend -# are maintained using the jobj. +# are maintained using the jobj. #' @include generics.R NULL diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7694652856da5..ebc6ff65e9d0f 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -215,7 +215,6 @@ setMethod("partitionBy", serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, - as.character(.sparkREnv$libname), broadcastArr, callJMethod(jrdd, "classTag")) @@ -329,7 +328,7 @@ setMethod("reduceByKey", convertEnvsToList(keys, vals) } locallyReduced <- lapplyPartition(x, reduceVals) - shuffled <- partitionBy(locallyReduced, numPartitions) + shuffled <- partitionBy(locallyReduced, numToInt(numPartitions)) lapplyPartition(shuffled, reduceVals) }) @@ -436,7 +435,7 @@ setMethod("combineByKey", convertEnvsToList(keys, combiners) } locallyCombined <- lapplyPartition(x, combineLocally) - shuffled <- partitionBy(locallyCombined, numPartitions) + shuffled <- partitionBy(locallyCombined, numToInt(numPartitions)) mergeAfterShuffle <- function(part) { combiners <- new.env() keys <- new.env() @@ -560,8 +559,8 @@ setMethod("join", # Left outer join two RDDs # # @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. 
Should be an RDD where each element is # list(K, V). @@ -597,8 +596,8 @@ setMethod("leftOuterJoin", # Right outer join two RDDs # # @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -634,8 +633,8 @@ setMethod("rightOuterJoin", # Full outer join two RDDs # # @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -784,7 +783,7 @@ setMethod("sortByKey", newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) - + # Subtract a pair RDD with another pair RDD. # # Return an RDD with the pairs from x whose keys are not in other. @@ -820,7 +819,7 @@ setMethod("subtractByKey", }) # Return a subset of this RDD sampled by key. -# +# # @description # \code{sampleByKey} Create a sample of this RDD using variable sampling rates # for different keys as specified by fractions, a key to sampling rate map. diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index e442119086b17..15e2bdbd55d79 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -20,7 +20,7 @@ #' structType #' -#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' Create a structType object that contains the metadata for a DataFrame. Intended for #' use with createDataFrame and toDF. #' #' @param x a structField object (created with the field() function) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index c53d0a961016f..78535eff0d2f6 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -37,6 +37,14 @@ writeObject <- function(con, object, writeType = TRUE) { # passing in vectors as arrays and instead require arrays to be passed # as lists. type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # lists and pairlists + if (type %in% c("integer", "character", "logical", "double", "numeric")) { + if (is.na(object)) { + object <- NULL + type <- "NULL" + } + } if (writeType) { writeType(con, type) } @@ -160,6 +168,14 @@ writeList <- function(con, arr) { } } +# Used to pass arrays where the elements can be of different types +writeGenericList <- function(con, list) { + writeInt(con, length(list)) + for (elem in list) { + writeObject(con, elem) + } +} + # Used to pass in hash maps required on Java side. 
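
[Editor's note] The serialize.R change above converts `NA` in atomic types to `NULL` before shipping values to the JVM, and adds `writeGenericList` so environments holding mixed value types serialize correctly. A small sketch of the user-visible effect, assuming an active `sc` and `sqlContext`:

```
# NA entries in local data become SQL nulls in the DataFrame
# and come back as NA when collected
df <- createDataFrame(sqlContext, data.frame(x = c(1L, NA_integer_, 3L)))
collect(df)$x   # 1 NA 3
```
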
writeEnv <- function(con, env) { len <- length(env) @@ -168,7 +184,7 @@ writeEnv <- function(con, env) { if (len > 0) { writeList(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeList(con, as.list(vals)) + writeGenericList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index bc82df01f0fff..172335809dec2 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -17,10 +17,6 @@ .sparkREnv <- new.env() -sparkR.onLoad <- function(libname, pkgname) { - .sparkREnv$libname <- libname -} - # Utility function that returns TRUE if we have an active connection to the # backend and FALSE otherwise connExists <- function(env) { @@ -43,7 +39,7 @@ sparkR.stop <- function() { callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) } - + if (exists(".backendLaunched", envir = env)) { callJStatic("SparkRHandler", "stopBackend") } @@ -80,7 +76,7 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes. #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -100,14 +96,15 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "") { + sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { - cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "512m") + sparkMem <- Sys.getenv("SPARK_MEM", "1024m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows @@ -129,7 +126,8 @@ sparkR.init <- function( args = path, sparkHome = sparkHome, jars = jars, - sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), + packages = sparkPackages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -166,25 +164,23 @@ sparkR.init <- function( sparkHome <- normalizePath(sparkHome) } - if (nchar(sparkRLibDir) != 0) { - .sparkREnv$libname <- sparkRLibDir - } - sparkEnvirMap <- new.env() for (varname in names(sparkEnvir)) { sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] } - + sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { - sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } for (varname in names(sparkExecutorEnv)) { sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs # Seconds resolution is good enough for this purpose, so use ints @@ -214,7 +210,7 @@ sparkR.init <- 
function( #' Initialize a new SQLContext. #' -#' This function creates a SparkContext from an existing JavaSparkContext and +#' This function creates a SparkContext from an existing JavaSparkContext and #' then uses it to initialize a new SQLContext #' #' @param jsc The existing JavaSparkContext created with SparkR.init() @@ -222,19 +218,26 @@ sparkR.init <- function( #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #'} -sparkRSQL.init <- function(jsc) { +sparkRSQL.init <- function(jsc = NULL) { if (exists(".sparkRSQLsc", envir = .sparkREnv)) { return(get(".sparkRSQLsc", envir = .sparkREnv)) } - sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - jsc) - assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) - sqlCtx + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + sc) + assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) + sqlContext } #' Initialize a new HiveContext. @@ -246,15 +249,22 @@ sparkRSQL.init <- function(jsc) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRHive.init(sc) +#' sqlContext <- sparkRHive.init(sc) #'} -sparkRHive.init <- function(jsc) { +sparkRHive.init <- function(jsc = NULL) { if (exists(".sparkRHivesc", envir = .sparkREnv)) { return(get(".sparkRHivesc", envir = .sparkREnv)) } - ssc <- callJMethod(jsc, "sc") + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) }, error = function(err) { @@ -264,3 +274,47 @@ sparkRHive.init <- function(jsc) { assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) hiveCtx } + +#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a +#' different value or cleared. 
+#' +#' @param sc existing spark context +#' @param groupid the ID to be assigned to job groups +#' @param description description for the the job group ID +#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE) +#'} + +setJobGroup <- function(sc, groupId, description, interruptOnCancel) { + callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) +} + +#' Clear current job group ID and its description +#' +#' @param sc existing spark context +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' clearJobGroup(sc) +#'} + +clearJobGroup <- function(sc) { + callJMethod(sc, "clearJobGroup") +} + +#' Cancel active jobs for the specified group +#' +#' @param sc existing spark context +#' @param groupId the ID of job group to be cancelled +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' cancelJobGroup(sc, "myJobGroup") +#'} + +cancelJobGroup <- function(sc, groupId) { + callJMethod(sc, "cancelJobGroup", groupId) +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 0e7b7bd5a5b34..ea629a64f7158 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -122,13 +122,49 @@ hashCode <- function(key) { intBits <- packBits(rawToBits(rawVec), "integer") as.integer(bitwXor(intBits[2], intBits[1])) } else if (class(key) == "character") { - .Call("stringHashCode", key) + # TODO: SPARK-7839 means we might not have the native library available + if (is.loaded("stringHashCode")) { + .Call("stringHashCode", key) + } else { + n <- nchar(key) + if (n == 0) { + 0L + } else { + asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) }) + hashC <- 0 + for (k in 1:length(asciiVals)) { + hashC <- mult31AndAdd(hashC, asciiVals[k]) + } + as.integer(hashC) + } + } } else { warning(paste("Could not hash object, returning 0", sep = "")) as.integer(0) } } +# Helper function used to wrap a 'numeric' value to integer bounds. +# Useful for implementing C-like integer arithmetic +wrapInt <- function(value) { + if (value > .Machine$integer.max) { + value <- value - 2 * .Machine$integer.max - 2 + } else if (value < -1 * .Machine$integer.max) { + value <- 2 * .Machine$integer.max + value + 2 + } + value +} + +# Multiply `val` by 31 and add `addVal` to the result. Ensures that +# integer-overflows are handled at every step. +mult31AndAdd <- function(val, addVal) { + vec <- c(bitwShiftL(val, c(4,3,2,1,0)), addVal) + Reduce(function(a, b) { + wrapInt(as.numeric(a) + as.numeric(b)) + }, + vec) +} + # Create a new RDD with serializedMode == "byte". # Return itself if already in "byte" format. 
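
[Editor's note] The new `setJobGroup`, `clearJobGroup`, and `cancelJobGroup` helpers added above expose the corresponding JavaSparkContext methods to R. A usage sketch combining them; the group id and description are illustrative:

```
sc <- sparkR.init()

# Tag everything started from this thread with a group id
setJobGroup(sc, "etlJobs", "nightly ETL jobs", TRUE)

# ... trigger some jobs, e.g. count(parallelize(sc, 1:1000)) ...

# Cancel whatever is still running in the group, then clear the tag
cancelJobGroup(sc, "etlJobs")
clearJobGroup(sc)
```
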
serializeToBytes <- function(rdd) { @@ -298,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "MEMORY_ONLY_SER_2", "OFF_HEAP")) { match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" storageLevel <- switch(newLevel, - "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), - "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), - "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), - "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), - "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), - "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), - "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), - "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), - "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), - "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), - "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } # Utility function for functions where an argument needs to be integer but we want to allow @@ -332,21 +371,21 @@ listToSeq <- function(l) { } # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a -# user defined function (UDF), and to examine variables in the UDF to decide +# user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. # param # node The current AST node in the traversal. # oldEnv The original function environment. # defVars An Accumulator of variables names defined in the function's calling environment, # including function argument and local variable names. -# checkedFunc An environment of function objects examined during cleanClosure. It can +# checkedFunc An environment of function objects examined during cleanClosure. It can # be considered as a "name"-to-"list of functions" mapping. # newEnv A new function environment to store necessary function dependencies, an output argument. processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { nodeLen <- length(node) - + if (nodeLen > 1 && typeof(node) == "language") { - # Recursive case: current AST node is an internal node, check for its children. + # Recursive case: current AST node is an internal node, check for its children. 
if (length(node[[1]]) > 1) { for (i in 1:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) @@ -357,7 +396,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 2:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "<-" || nodeChar == "=" || + } else if (nodeChar == "<-" || nodeChar == "=" || nodeChar == "<<-") { # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { @@ -386,21 +425,21 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } } - } else if (nodeLen == 1 && + } else if (nodeLen == 1 && (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. nodeChar <- as.character(node) if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) - # Search in function environment, and function's enclosing environments + # Search in function environment, and function's enclosing environments # up to global environment. There is no need to look into package environments - # above the global or namespace environment that is not SparkR below the global, + # above the global or namespace environment that is not SparkR below the global, # as they are assumed to be loaded on workers. while (!identical(func.env, topEnv)) { # Namespaces other than "SparkR" will not be searched. - if (!isNamespace(func.env) || - (getNamespaceName(func.env) == "SparkR" && + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. @@ -408,7 +447,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) if (is.function(obj)) { # If the node is a function call. - funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, ifnotfound = list(list(NULL)))[[1]] found <- sapply(funcList, function(func) { ifelse(identical(func, obj), TRUE, FALSE) @@ -417,7 +456,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { break } # Function has not been examined, record it and recursively clean its closure. - assign(nodeChar, + assign(nodeChar, if (is.null(funcList[[1]])) { list(obj) } else { @@ -430,7 +469,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { break } } - + # Continue to search in enclosure. func.env <- parent.env(func.env) } @@ -438,8 +477,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } -# Utility function to get user defined function (UDF) dependencies (closure). -# More specifically, this function captures the values of free variables defined +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined # outside a UDF, and stores them in the function's environment. # param # func A function whose closure needs to be captured. 
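
[Editor's note] The comments being cleaned up above describe SparkR's closure capture: `cleanClosure` walks a UDF's AST and copies the values of free variables it references into a fresh environment so the function can be shipped to workers. A rough illustration of the internal helper (not part of the public API; behavior sketched from the comments above), assuming an interactive session:

```
base <- 100
addBase <- function(x) { x + base }

# cleanClosure returns the function with a new environment that
# holds a copy of the captured free variable `base`
cleaned <- SparkR:::cleanClosure(addBase)
get("base", envir = environment(cleaned))   # 100
```
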
@@ -452,7 +491,7 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { newEnv <- new.env(parent = .GlobalEnv) func.body <- body(func) oldEnv <- environment(func) - # defVars is an Accumulator of variables names defined in the function's calling + # defVars is an Accumulator of variables names defined in the function's calling # environment. First, function's arguments are added to defVars. defVars <- initAccumulator() argNames <- names(as.list(args(func))) @@ -473,15 +512,15 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { # return value # A list of two result RDDs. appendPartitionLengths <- function(x, other) { - if (getSerializedMode(x) != getSerializedMode(other) || + if (getSerializedMode(x) != getSerializedMode(other) || getSerializedMode(x) == "byte") { # Append the number of elements in each partition to that partition so that we can later # know the boundary of elements from x and other. # - # Note that this appending also serves the purpose of reserialization, because even if + # Note that this appending also serves the purpose of reserialization, because even if # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded # as a single byte array. For example, partitions of an RDD generated from partitionBy() - # may be encoded as multiple byte arrays. + # may be encoded as multiple byte arrays. appendLength <- function(part) { len <- length(part) part[[len + 1]] <- len + 1 @@ -508,23 +547,25 @@ mergePartitions <- function(rdd, zip) { lengthOfValues <- part[[len]] lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - - # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. 
if (zip && lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) } - + if (lengthOfKeys > 1) { keys <- part[1 : (lengthOfKeys - 1)] } else { keys <- list() } if (lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] } else { values <- list() } - + if (!zip) { return(mergeCompactLists(keys, values)) } @@ -542,6 +583,6 @@ mergePartitions <- function(rdd, zip) { part } } - + PipelinedRDD(rdd, partitionFunc) } diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8fe711b622086..2a8a8213d0849 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,7 +16,7 @@ # .First <- function() { - home <- Sys.getenv("SPARK_HOME") - .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") + .libPaths(c(packageDir, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 33478d9e29995..7189f1a260934 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -24,10 +24,24 @@ old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = "")) + sc <- SparkR::sparkR.init() assign("sc", sc, envir=.GlobalEnv) - sqlCtx <- SparkR::sparkRSQL.init(sc) - assign("sqlCtx", sqlCtx, envir=.GlobalEnv) - cat("\n Welcome to SparkR!") - cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") + sqlContext <- SparkR::sparkRSQL.init(sc) + sparkVer <- SparkR:::callJMethod(sc, "version") + assign("sqlContext", sqlContext, envir=.GlobalEnv) + cat("\n Welcome to") + cat("\n") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + if (nchar(sparkVer) == 0) { + cat("\n") + } else { + cat(" version ", sparkVer, "\n") + } + cat(" /_/", "\n") + cat("\n") + + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") } diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar new file mode 100644 index 0000000000000..1d5c2af631aa3 Binary files /dev/null and b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar differ diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/jarTest.R new file mode 100644 index 0000000000000..d68bb20950b00 --- /dev/null +++ b/R/pkg/inst/tests/jarTest.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +library(SparkR) + +sc <- sparkR.init() + +helloTest <- SparkR:::callJStatic("sparkR.test.hello", + "helloWorld", + "Dave") + +basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction", + "addStuff", + 2L, + 2L) + +sparkR.stop() +output <- c(helloTest, basicFunction) +writeLines(output) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ca4218f3819f8..ccaea18ecab2a 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -59,15 +59,15 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", wordCount <- lapply(words, function(word) { list(word, 1L) }) counts <- reduceByKey(wordCount, "+", 2L) - + saveAsObjectFile(counts, fileName2) counts <- objectFile(sc, fileName2) - + output <- collect(counts) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName1) unlink(fileName2, recursive = TRUE) }) @@ -82,9 +82,8 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) }) - diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 6785a7bdae8cb..3be8c65a6c1a0 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -30,7 +30,7 @@ mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { actual <- collect(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) - + fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") rdd<- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) @@ -52,14 +52,14 @@ test_that("union on two RDDs", { test_that("cogroup on two RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) - expect_equal(actual, + expect_equal(actual, list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) - + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) @@ -71,31 +71,31 @@ test_that("zipPartitions() on RDDs", { rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 - actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { 
list(list(x, y, z))} )) expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - + mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName, 1) - actual <- collect(zipPartitions(rdd, rdd, + actual <- collect(zipPartitions(rdd, rdd, func = function(x, y) { list(paste(x, y, sep = "\n")) })) expected <- list(paste(mockFile, mockFile, sep = "\n")) expect_equal(actual, expected) - + rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipPartitions(rdd1, rdd, + actual <- collect(zipPartitions(rdd1, rdd, func = function(x, y) { list(x + nchar(y)) })) expected <- list(0:1 + nchar(mockFile)) expect_equal(actual, expected) - + rdd <- map(rdd, function(x) { x }) - actual <- collect(zipPartitions(rdd, rdd1, + actual <- collect(zipPartitions(rdd, rdd1, func = function(x, y) { list(y + nchar(x)) })) expect_equal(actual, expected) - + unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R new file mode 100644 index 0000000000000..30b05c1a2afcd --- /dev/null +++ b/R/pkg/inst/tests/test_client.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in client.R") + +test_that("adding spark-testing-base as a package works", { + args <- generateSparkSubmitArgs("", "", "", "", + "holdenk:spark-testing-base:1.3.0_0.0.5") + expect_equal(gsub("[[:space:]]", "", args), + gsub("[[:space:]]", "", + "--packages holdenk:spark-testing-base:1.3.0_0.0.5")) +}) + +test_that("no package specified doesn't add packages flag", { + args <- generateSparkSubmitArgs("", "", "", "", "") + expect_equal(gsub("[[:space:]]", "", args), + "") +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index e4aab37436a74..513bbc8e62059 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -48,3 +48,10 @@ test_that("rdd GC across sparkR.stop", { count(rdd3) count(rdd4) }) + +test_that("job group functions can be called", { + sc <- sparkR.init() + setJobGroup(sc, "groupId", "job description", TRUE) + cancelJobGroup(sc, "groupId") + clearJobGroup(sc) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R new file mode 100644 index 0000000000000..cc1faeabffe30 --- /dev/null +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +context("include an external JAR in SparkContext") + +runScript <- function() { + sparkHome <- Sys.getenv("SPARK_HOME") + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + submitPath <- file.path(sparkHome, "bin/spark-submit") + res <- system2(command = submitPath, + args = c(jarPath, scriptPath), + stdout = TRUE) + tail(res, 2) +} + +test_that("sparkJars tag in SparkContext", { + testOutput <- runScript() + helloTest <- testOutput[1] + expect_equal(helloTest, "Hello, Dave") + basicFunction <- testOutput[2] + expect_equal(basicFunction, "4") +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db37..2552127cc547f 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 03207353c31c6..b79692873cec3 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { @@ -477,7 +477,7 @@ test_that("cartesian() on RDDs", { list(1, 1), list(1, 2), list(1, 3), list(2, 1), list(2, 2), list(2, 3), list(3, 1), list(3, 2), list(3, 3))) - + # test case where one RDD is empty emptyRdd <- parallelize(sc, list()) actual <- collect(cartesian(rdd, emptyRdd)) @@ -486,7 +486,7 @@ test_that("cartesian() on RDDs", { mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName) actual <- collect(cartesian(rdd, rdd)) expected <- list( @@ -495,7 +495,7 @@ test_that("cartesian() on RDDs", { list("Spark is pretty.", "Spark is pretty."), list("Spark is pretty.", "Spark is awesome.")) expect_equal(sortKeyValueList(actual), expected) - + rdd1 <- parallelize(sc, 0:1) actual <- collect(cartesian(rdd1, rdd)) expect_equal(sortKeyValueList(actual), @@ -504,11 +504,11 @@ test_that("cartesian() on RDDs", { list(0, "Spark is awesome."), list(1, "Spark is pretty."), list(1, "Spark is awesome."))) - + rdd1 <- map(rdd, function(x) { x }) actual <- collect(cartesian(rdd, rdd1)) expect_equal(sortKeyValueList(actual), expected) - + unlink(fileName) }) @@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) 
actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("sortByKey() on pairwise RDDs", { @@ -760,7 +764,7 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { - rdd <- parallelize(sc, list(1:10)) + rdd <- parallelize(sc, list(1:10)) expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R index d7dedda553c56..adf0b91d25fe9 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/test_shuffle.R @@ -106,39 +106,39 @@ test_that("aggregateByKey", { zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + actual <- collect(aggregatedRDD) - + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test aggregateByKey for string keys rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) - + zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) actual <- collect(aggregatedRDD) - + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) -test_that("foldByKey", { +test_that("foldByKey", { # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list(2L, 101), list(1L, 199)) 
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test foldByKey for double keys folded <- foldByKey(doubleRdd, 0, "+", 2L) - + actual <- collect(folded) expected <- list(list(1.5, 199), list(2.5, 101)) @@ -146,15 +146,15 @@ test_that("foldByKey", { # test foldByKey for string keys stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) - + stringKeyRDD <- parallelize(sc, stringKeyPairs) folded <- foldByKey(stringKeyRDD, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list("b", 101), list("a", 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - + # test foldByKey for empty pair RDD rdd <- parallelize(sc, list()) folded <- foldByKey(rdd, 0, "+", 2L) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 99c28830c6237..b0ea38854304e 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -19,11 +19,19 @@ library(testthat) context("SparkSQL functions") +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() -sqlCtx <- sparkRSQL.init(sc) +sqlContext <- sparkRSQL.init(sc) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", @@ -32,6 +40,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") writeLines(mockLines, jsonPath) +# For test nafunctions, like dropna(), fillna(),... 
+mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -43,9 +60,10 @@ test_that("infer types", { list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(a = 1L, b = "2")), - structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE))) + testStruct <- infer_type(list(a = 1L, b = "2")) + expect_equal(class(testStruct), "structType") + checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) + checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), @@ -55,83 +73,120 @@ test_that("infer types", { test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) - + testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - df <- createDataFrame(sqlCtx, rdd) - expect_true(inherits(df, "DataFrame")) + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlCtx, rdd, schema) - expect_true(inherits(df, "DataFrame")) + df <- createDataFrame(sqlContext, rdd, schema) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- 
createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) test_that("create DataFrame from list or data.frame", { l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlCtx, l, c("a", "b")) + df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) l <- list(list(a=1, b=2), list(a=3, b=4)) - df <- createDataFrame(sqlCtx, l) + df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) - df <- createDataFrame(sqlCtx, ldf) + df <- createDataFrame(sqlContext, ldf) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(count(df), 3) @@ -142,7 +197,7 @@ test_that("create DataFrame from list or data.frame", { test_that("create DataFrame with different data types", { l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlCtx, list(l)) + df <- createDataFrame(sqlContext, list(l)) expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), c("d", "string"), c("e", "date"), c("f", "timestamp"))) expect_equal(count(df), 1) @@ -154,7 +209,7 @@ test_that("create DataFrame with different data types", { 
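
[Editor's note] The `mockLinesNa` fixture defined above feeds the tests for the NA-handling API whose generics (`dropna`, `na.omit`, `fillna`) are declared in the generics.R hunk earlier in this patch (the implementations live in DataFrame.R, outside this excerpt). An illustrative sketch, assuming the semantics mirror Spark's DataFrameNaFunctions:

```
df <- jsonFile(sqlContext, jsonPathNa)

# Drop rows containing any null (how = "any" is the default)
collect(dropna(df))

# Keep rows with at least 2 non-null values instead
collect(dropna(df, minNonNulls = 2))

# Replace nulls in numeric columns with 0
collect(fillna(df, 0))
```
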
# e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) -# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) # expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), # c("c", "map"), c("d", "struct"))) # expect_equal(count(df), 1) @@ -163,102 +218,102 @@ test_that("create DataFrame with different data types", { #}) test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + df <- jsonFile(sqlContext, jsonPath) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) - df <- jsonRDD(sqlCtx, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_equal(count(rdd), 3) + df <- jsonRDD(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlCtx, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + df <- jsonRDD(sqlContext, rdd2) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - cacheTable(sqlCtx, "table1") - uncacheTable(sqlCtx, "table1") - clearCache(sqlCtx) - dropTempTable(sqlCtx, "table1") + cacheTable(sqlContext, "table1") + uncacheTable(sqlContext, "table1") + clearCache(sqlContext) + dropTempTable(sqlContext, "table1") }) test_that("test tableNames and tables", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlCtx)) == 1) - df <- tables(sqlCtx) - expect_true(count(df) == 1) - dropTempTable(sqlCtx, "table1") + expect_equal(length(tableNames(sqlContext)), 1) + df <- tables(sqlContext) + expect_equal(count(df), 1) + dropTempTable(sqlContext, "table1") }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) - dropTempTable(sqlCtx, "table1") + newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) + dropTempTable(sqlContext, "table1") }) test_that("insertInto() on a registered table", { - df <- loadDF(sqlCtx, jsonPath, "json") - saveDF(df, parquetPath, "parquet", "overwrite") - dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(sqlContext, parquetPath, "parquet") lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) - df2 <- loadDF(sqlCtx, jsonPath2, "json") - saveDF(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + df2 <- read.df(sqlContext, jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + 
dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlCtx, "select * from table1")) == 5) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") - dropTempTable(sqlCtx, "table1") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") + dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlCtx, "select * from table1")) == 2) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") - dropTempTable(sqlCtx, "table1") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") + dropTempTable(sqlContext, "table1") }) test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - tabledf <- table(sqlCtx, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) - dropTempTable(sqlCtx, "table1") + tabledf <- table(sqlContext, "table1") + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) + dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -274,70 +329,70 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern="spark-test", 
fileext=".tmp") - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlCtx, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + df <- jsonFile(sqlContext, jsonPath) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) -test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { - df <- jsonFile(sqlCtx, jsonPath) +test_that("multiple pipeline transformations result in an RDD with the correct values", { + df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -346,15 +401,15 @@ test_that("multiple pipeline transformations starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - expect_true(inherits(second, "RDD")) - expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -373,38 +428,38 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + 
expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) }) test_that("distinct() on DataFrames", { @@ -415,18 +470,18 @@ test_that("distinct() on DataFrames", { jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlCtx, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) }) -test_that("sampleDF on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) - sampled <- sampleDF(df, FALSE, 1.0) +test_that("sample on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) - sampled2 <- sampleDF(df, FALSE, 0.1) + expect_is(sampled, "DataFrame") + sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) # Also test sample_frac @@ -435,16 +490,16 @@ test_that("sampleDF on a DataFrame", { }) test_that("select operators", { - df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age @@ -461,48 +516,61 @@ test_that("select operators", { }) test_that("select with column", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + 
expect_equal(count(df2), 3) }) test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) }) test_that("column calculation", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) -test_that("load() from json file", { - df <- loadDF(sqlCtx, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) +test_that("read.df() from json file", { + df <- read.df(sqlContext, jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Check if we can apply a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_is(df1, "DataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Run the same with loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_is(df2, "DataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) -test_that("save() as parquet file", { - df <- loadDF(sqlCtx, jsonPath, "json") - saveDF(df, parquetPath, "parquet", mode="overwrite") - df2 <- loadDF(sqlCtx, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) +test_that("write.df() as parquet file", { + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", mode="overwrite") + df2 <- read.df(sqlContext, parquetPath, "parquet") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { @@ -512,17 +580,17 @@ test_that("test HiveContext", { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -530,6 +598,7 @@ test_that("column operators", { c2 <- (- c + 1 - 2) * 3 / 4.0 c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) + c5 <- c2 ^ c3 ^ c4 }) test_that("column functions", { @@ -538,10 +607,33 @@ test_that("column functions", { c3 <- lower(c) + upper(c) + first(c) + last(c) c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") c5 <- n(c) + n_distinct(c) + c5 <- acos(c) + asin(c) + 
atan(c) + cbrt(c) + c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) + c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) + c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) + c9 <- toDegrees(c) + toRadians(c) +}) + +test_that("column binary mathfunctions", { + lines <- c("{\"a\":1, \"b\":5}", + "{\"a\":2, \"b\":6}", + "{\"a\":3, \"b\":7}", + "{\"a\":4, \"b\":8}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) + expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) }) test_that("string operators", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ -549,71 +641,81 @@ test_that("string operators", { }) test_that("group by", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, "name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - 
expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) }) test_that("join() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -621,125 +723,232 @@ test_that("join() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlCtx, jsonPath2) + df2 <- jsonFile(sqlContext, jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) test_that("showDF()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) - expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + 
"| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) }) test_that("isLocal()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPath2) - df2 <- loadDF(sqlCtx, jsonPath2, "json") + df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("mutate() and rename()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) -test_that("saveDF() on DataFrame and works with parquetFile", { - df <- jsonFile(sqlCtx, jsonPath) - saveDF(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath) - expect_true(inherits(parquetDF, "DataFrame")) +test_that("write.df() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlContext, jsonPath) + write.df(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlContext, parquetPath) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) test_that("parquetFile works with multiple input 
paths", { - df <- jsonFile(sqlCtx, jsonPath) - saveDF(df, parquetPath, "parquet", mode="overwrite") + df <- jsonFile(sqlContext, jsonPath) + write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - saveDF(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + write.df(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df)*2) }) test_that("describe() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) stats <- describe(df, "age") - expect_true(collect(stats)[1, "summary"] == "count") - expect_true(collect(stats)[2, "age"] == 24.5) - expect_true(collect(stats)[3, "age"] == 5.5) + expect_equal(collect(stats)[1, "summary"], "count") + expect_equal(collect(stats)[2, "age"], "24.5") + expect_equal(collect(stats)[3, "age"], "5.5") stats <- describe(df) - expect_true(collect(stats)[4, "name"] == "Andy") - expect_true(collect(stats)[5, "age"] == 30.0) + expect_equal(collect(stats)[4, "name"], "Andy") + expect_equal(collect(stats)[5, "age"], "30") +}) + +test_that("dropna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. 
+ expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) +}) + +test_that("fillna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_identical(expected, actual) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_identical(expected, actual) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_identical(expected, actual) }) unlink(parquetPath) unlink(jsonPath) +unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index 7f4c7c315d787..c2c724cdc762f 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,9 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + 
expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) - diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 6b87b4b3e0b08..58318dfef71ab 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -58,7 +58,7 @@ test_that("textFile() word count works as expected", { expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), list("Spark", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName) }) @@ -115,13 +115,13 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { saveAsTextFile(counts, fileName2) rdd <- textFile(sc, fileName2) - + output <- collect(rdd) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expectedStr <- lapply(expected, function(x) { toString(x) }) expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) - + unlink(fileName1) unlink(fileName2) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) @@ -159,4 +159,3 @@ test_that("Pipelined operations on RDDs created using textFile", { unlink(fileName) }) - diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 539e3a3c19df3..aa0d2a66b9082 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -43,13 +43,13 @@ test_that("serializeToBytes on RDD", { mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") - + expect_equal(getSerializedMode(ser.rdd), "byte") + unlink(fileName) }) @@ -64,7 +64,7 @@ test_that("cleanClosure on R functions", { expect_equal(actual, y) actual <- get("g", envir = env, inherits = FALSE) expect_equal(actual, g) - + # Test for nested enclosures and package variables. env2 <- new.env() funcEnv <- new.env(parent = env2) @@ -106,7 +106,7 @@ test_that("cleanClosure on R functions", { expect_equal(length(ls(env)), 1) actual <- get("y", envir = env, inherits = FALSE) expect_equal(actual, y) - + # Test for function (and variable) definitions. f <- function(x) { g <- function(y) { y * 2 } @@ -115,7 +115,7 @@ test_that("cleanClosure on R functions", { newF <- cleanClosure(f) env <- environment(newF) expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. - + # Test for overriding variables in base namespace (Issue: SparkR-196). 
nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) @@ -128,7 +128,7 @@ test_that("cleanClosure on R functions", { actual <- collect(lapply(rdd, f)) expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) expect_equal(actual, expected) - + # Test for broadcast variables. a <- matrix(nrow=10, ncol=10, data=rnorm(100)) aBroadcast <- broadcast(sc, a) diff --git a/R/pkg/src/Makefile b/R/pkg/src-native/Makefile similarity index 100% rename from R/pkg/src/Makefile rename to R/pkg/src-native/Makefile diff --git a/R/pkg/src/Makefile.win b/R/pkg/src-native/Makefile.win similarity index 100% rename from R/pkg/src/Makefile.win rename to R/pkg/src-native/Makefile.win diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src-native/string_hash_code.c similarity index 100% rename from R/pkg/src/string_hash_code.c rename to R/pkg/src-native/string_hash_code.c diff --git a/README.md b/README.md index 9c09d40e2bdae..380422ca00dbe 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a -rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLlib for machine learning, GraphX for graph processing, +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. @@ -22,7 +22,7 @@ This README file only contains basic setup instructions. Spark is built using [Apache Maven](http://maven.apache.org/). To build Spark and its example programs, run: - mvn -DskipTests clean package + build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at @@ -43,7 +43,7 @@ Try the following command, which should return 1000: Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark - + And run the following command, which should also return 1000: >>> sc.parallelize(range(1000)).count() @@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -75,7 +75,7 @@ can be run using: ./dev/run-tests -Please see the guidance on how to +Please see the guidance on how to [run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). 
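For reference, the build-and-verify workflow that the README text above describes amounts to the following two commands (quoted verbatim from the README; no additional options are assumed):

```
# Build Spark with the bundled Maven wrapper, then run the full test suite.
build/mvn -DskipTests clean package
./dev/run-tests
```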
## A Note About Hadoop Versions diff --git a/assembly/pom.xml b/assembly/pom.xml index 626c8577e31fe..e9c6d26ccddc7 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 1f3dec91314f2..ed5c37e595a96 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index ccb262a4ee02a..fb10d734ac74b 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.bagel -import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { +class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ diff --git a/bin/pyspark b/bin/pyspark index 8acad6113797d..f9dbddfa53560 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,24 +17,10 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/pyspark [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi +export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. @@ -90,11 +76,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 09b4149c2a439..45e9e3def5121 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -21,6 +21,7 @@ rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. call %SPARK_HOME%\bin\load-spark-env.cmd +set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/spark-class b/bin/spark-class index c49d97ce5cf25..2b59e5df5736f 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,18 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -set -e # Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" . "$SPARK_HOME"/bin/load-spark-env.sh -if [ -z "$1" ]; then - echo "Usage: spark-class []" 1>&2 - exit 1 -fi - # Find the java binary if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" @@ -64,24 +58,6 @@ fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" -# Verify that versions of java used to build the jars and run Spark are compatible -if [ -n "$JAVA_HOME" ]; then - JAR_CMD="$JAVA_HOME/bin/jar" -else - JAR_CMD="jar" -fi - -if [ $(command -v "$JAR_CMD") ] ; then - jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1) - if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then - echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 - echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 - echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 - echo "or build Spark with Java 6." 1>&2 - exit 1 - fi -fi - LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" # Add the launcher build dir to the classpath if requested. @@ -98,9 +74,4 @@ CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") - -if [ "${CMD[0]}" = "usage" ]; then - "${CMD[@]}" -else - exec "${CMD[@]}" -fi +exec "${CMD[@]}" diff --git a/bin/spark-shell b/bin/spark-shell index b3761b5e1375b..a6dc863d83fc6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -29,20 +29,7 @@ esac set -o posix export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -usage() { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-shell [options]" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi +export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, # so we need to add the "-Dscala.usejavacp=true" flag manually. We diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 00fd30fa38d36..251309d67f860 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -18,12 +18,7 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. - -echo "%*" | findstr " \<--help\> \<-h\>" >nul -if %ERRORLEVEL% equ 0 ( - call :usage - exit /b 0 -) +set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] rem SPARK-4161: scala does not assume use of the java classpath, rem so we need to add the "-Dscala.usejavacp=true" flag manually. We @@ -37,16 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -echo "Usage: .\bin\spark-shell.cmd [options]" >&2 -call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2 -goto :eof +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql index ca1729f4cfcb4..4ea7bc6e39c07 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,41 +17,6 @@ # limitations under the License. # -# -# Shell script for starting the Spark SQL CLI - -# Enter posix mode for bash -set -o posix - -# NOTE: This exact class name is matched downstream by SparkSubmit. -# Any changes need to be reflected there. 
-export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" - -# Figure out where Spark is installed export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -function usage { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-sql [options] [cli option]" - pattern="usage" - pattern+="\|Spark assembly has been built with Hive" - pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" - pattern+="\|Spark Command: " - pattern+="\|--help" - pattern+="\|=======" - - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - echo - echo "CLI options:" - "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi - -exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" +exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 0e0afe71a0f05..255378b0f077c 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,16 +22,4 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 -# Only define a usage function if an upstream script hasn't done so. -if ! type -t usage >/dev/null 2>&1; then - usage() { - if [ -n "$1" ]; then - echo "$1" - fi - "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help - exit "$2" - } - export -f usage -fi - exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index d3fc4a5cc3f6e..651376e526928 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,15 +24,4 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -call %~dp0spark-class2.cmd %CLASS% %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help -goto :eof +%~dp0spark-class2.cmd %CLASS% %* diff --git a/bin/sparkR b/bin/sparkR index 8c918e2b09aef..464c29f369424 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,23 +17,7 @@ # limitations under the License. 
# -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" - source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/sparkR [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi - +export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/build/mvn b/build/mvn index 3561110a4c019..f62f61ee1c416 100755 --- a/build/mvn +++ b/build/mvn @@ -69,11 +69,14 @@ install_app() { # Install maven under the build/ folder install_mvn() { + local MVN_VERSION="3.3.3" + install_app \ - "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ - "apache-maven-3.2.5-bin.tar.gz" \ - "apache-maven-3.2.5/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "apache-maven-${MVN_VERSION}-bin.tar.gz" \ + "apache-maven-${MVN_VERSION}/bin/mvn" + + MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn" } # Install zinc under the build/ folder @@ -105,28 +108,23 @@ install_scala() { SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" } -# Determines if a given application is already installed. If not, will attempt -# to install -## Arg1 - application name -## Arg2 - Alternate path to local install under build/ dir -check_and_install_app() { - # create the local environment variable in uppercase - local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN" - # some black magic to set the generated app variable (i.e. MVN_BIN) into the - # environment - eval "${app_bin}=`which $1 2>/dev/null`" - - if [ -z "`which $1 2>/dev/null`" ]; then - install_$1 - fi -} - # Setup healthy defaults for the Zinc port if none were provided from # the environment ZINC_PORT=${ZINC_PORT:-"3030"} -# Check and install all applications necessary to build Spark -check_and_install_app "mvn" +# Check for the `--force` flag dictating that `mvn` should be downloaded +# regardless of whether the system already has a `mvn` install +if [ "$1" == "--force" ]; then + FORCE_MVN=1 + shift +fi + +# Install Maven if necessary +MVN_BIN="$(command -v mvn)" + +if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then + install_mvn +fi # Install the proper version of Scala and Zinc for the build install_zinc @@ -148,5 +146,7 @@ fi # Set any `mvn` options if not already present export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} +echo "Using \`mvn\` from path: $MVN_BIN" + # Last, call the `mvn` command as usual ${MVN_BIN} "$@" diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 2e0cb5db170ac..7f17bc7eea4f5 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -4,7 +4,7 @@ # divided into instances which correspond to internal components. # Each instance can be configured to report its metrics to one or more sinks. # Accepted values for [instance] are "master", "worker", "executor", "driver", -# and "applications". A wild card "*" can be used as an instance name, in +# and "applications". A wildcard "*" can be used as an instance name, in # which case all instances will inherit the supplied property. # # Within an instance, a "source" specifies a particular set of grouped metrics. @@ -32,7 +32,7 @@ # name (see examples below). # 2. Some sinks involve a polling period. 
The minimum allowed polling period # is 1 second. -# 3. Wild card properties can be overridden by more specific properties. +# 3. Wildcard properties can be overridden by more specific properties. # For example, master.sink.console.period takes precedence over # *.sink.console.period. # 4. A metrics specific configuration @@ -47,6 +47,13 @@ # instance master and applications. MetricsServlet may not be configured by self. # +## List of available common sources and their properties. + +# org.apache.spark.metrics.source.JvmSource +# Note: Currently, JvmSource is the only common source that can be +# added to an instance. To enable it, set the "class" option to its +# fully qualified class name (see examples below). + ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink @@ -126,9 +133,9 @@ #*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink # Polling period for Slf4JSink -#*.sink.sl4j.period=1 +#*.sink.slf4j.period=1 -#*.sink.sl4j.unit=minutes +#*.sink.slf4j.unit=minutes # Enable jvm source for instance master, worker, driver and executor diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 43c4288912b18..192d3ae091134 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -22,7 +22,7 @@ # - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job.
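The spark-env.sh.template hunk above only corrects the documented default for SPARK_DRIVER_MEMORY; as a hedged sketch of how the documented variables are typically overridden in conf/spark-env.sh, one might write something like the following (the values are illustrative assumptions, not recommended settings):

```
# Hypothetical conf/spark-env.sh overrides for variables documented in the template.
export SPARK_DRIVER_MEMORY=2G      # driver memory (template default is now documented as 1G)
export SPARK_EXECUTOR_MEMORY=2G    # memory per worker (template default: 1G)
export SPARK_EXECUTOR_CORES=2      # cores per worker (template default: 1)
```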
diff --git a/core/pom.xml b/core/pom.xml index 262a3320db106..558cc3fb9f2f3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -69,16 +69,6 @@ org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark @@ -338,6 +328,12 @@ org.seleniumhq.selenium selenium-java + + + com.google.guava + guava + + test @@ -346,9 +342,19 @@ xml-apis test + + org.hamcrest + hamcrest-core + test + + + org.hamcrest + hamcrest-library + test + org.mockito - mockito-all + mockito-core test @@ -367,9 +373,15 @@ test - org.spark-project + net.razorvine pyrolite 4.4 + + + net.razorvine + serpent + + net.sf.py4j diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 646496f313507..fa9acf0a15b88 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -17,23 +17,7 @@ package org.apache.spark; -import org.apache.spark.scheduler.SparkListener; -import org.apache.spark.scheduler.SparkListenerApplicationEnd; -import org.apache.spark.scheduler.SparkListenerApplicationStart; -import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; -import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; -import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorAdded; -import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorRemoved; -import org.apache.spark.scheduler.SparkListenerJobEnd; -import org.apache.spark.scheduler.SparkListenerJobStart; -import org.apache.spark.scheduler.SparkListenerStageCompleted; -import org.apache.spark.scheduler.SparkListenerStageSubmitted; -import org.apache.spark.scheduler.SparkListenerTaskEnd; -import org.apache.spark.scheduler.SparkListenerTaskGettingResult; -import org.apache.spark.scheduler.SparkListenerTaskStart; -import org.apache.spark.scheduler.SparkListenerUnpersistRDD; +import org.apache.spark.scheduler.*; /** * Java clients should extend this class instead of implementing @@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } @Override public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index fbc5666959055..1214d05ba6063 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { onEvent(executorRemoved); } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + } diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java new file mode 100644 index 0000000000000..0399abc63c235 --- /dev/null +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -0,0 +1,92 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.reflect.ClassTag; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + * Our shuffle write path doesn't actually use this serializer (since we end up calling the + * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + * around this, we pass a dummy no-op serializer. + */ +@Private +public final class DummySerializerInstance extends SerializerInstance { + + public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); + + private DummySerializerInstance() { } + + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { + // Need to implement this because DiskObjectWriter uses it to flush the compression stream + try { + s.flush(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + // Need to implement this because DiskObjectWriter uses it to close the compression stream + try { + s.close(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + }; + } + + @Override + public ByteBuffer serialize(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag ev1) { + throw new UnsupportedOperationException(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java new file mode 100644 index 0000000000000..0b8b604e18494 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path + * writes incoming records to separate files, one file per reduce partition, then concatenates these + * per-partition files to form a single output file, regions of which are served to reducers. + * Records are not buffered in memory. This is essentially identical to + * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. + *

+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ * <ul>
+ *    <li>no Ordering is specified,</li>
+ *    <li>no Aggregator is specified, and</li>
+ *    <li>the number of partitions is less than
+ *      <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ * <p>
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { + + private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final Serializer serializer; + + /** Array of file writers, one for each partition */ + private DiskBlockObjectWriter[] partitionWriters; + + public BypassMergeSortShuffleWriter( + SparkConf conf, + BlockManager blockManager, + Partitioner partitioner, + ShuffleWriteMetrics writeMetrics, + Serializer serializer) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.numPartitions = partitioner.numPartitions(); + this.blockManager = blockManager; + this.partitioner = partitioner; + this.writeMetrics = writeMetrics; + this.serializer = serializer; + } + + @Override + public void insertAll(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. 
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (DiskBlockObjectWriter writer : partitionWriters) { + writer.commitAndClose(); + } + } + + @Override + public long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException { + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public void stop() throws IOException { + if (partitionWriters != null) { + try { + final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + writer.revertPartialWritesAndClose(); + if (!diskBlockManager.getFile(writer.blockId()).delete()) { + logger.error("Error while deleting file for block {}", writer.blockId()); + } + } + } finally { + partitionWriters = null; + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java new file mode 100644 index 0000000000000..656ea0401a144 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Product2; +import scala.collection.Iterator; + +import org.apache.spark.annotation.Private; +import org.apache.spark.TaskContext; +import org.apache.spark.storage.BlockId; + +/** + * Interface for objects that {@link SortShuffleWriter} uses to write its output files. + */ +@Private +public interface SortShuffleFileWriter { + + void insertAll(Iterator> records) throws IOException; + + /** + * Write all the data added into this shuffle sorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException; + + void stop() throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java new file mode 100644 index 0000000000000..4ee6a82c0423e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +/** + * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. + *

+ * Within the long, the data is laid out as follows: + *

+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * 
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that + * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based on the + * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. + *
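As a concrete check of this arithmetic, the following self-contained sketch packs a partition id, page number, and in-page offset with the same 24/13/27-bit split described above and recovers them again. The class and method names are hypothetical, not part of Spark; the masks simply mirror the ones defined in the class below.

```
public final class PackedPointerDemo {
  // Layout of the packed word: [24-bit partition][13-bit page number][27-bit offset in page].
  private static final long MASK_LOWER_27_BITS = (1L << 27) - 1;
  private static final long MASK_LOWER_51_BITS = (1L << 51) - 1;
  private static final long MASK_UPPER_13_BITS = ~MASK_LOWER_51_BITS;

  static long pack(long recordPointer, int partitionId) {
    // The pointer keeps its 13-bit page number in the top bits; move it down so it
    // sits directly above the 27 offset bits, then place the partition id on top.
    final long pageNumber = (recordPointer & MASK_UPPER_13_BITS) >>> 24;
    return (((long) partitionId) << 40) | pageNumber | (recordPointer & MASK_LOWER_27_BITS);
  }

  static int partitionId(long packed) {
    return (int) (packed >>> 40);
  }

  static long recordPointer(long packed) {
    // Shift the page number back into the top 13 bits and re-attach the offset.
    final long pageNumber = (packed << 24) & MASK_UPPER_13_BITS;
    return pageNumber | (packed & MASK_LOWER_27_BITS);
  }

  public static void main(String[] args) {
    final long pointer = (5L << 51) | 1234L;               // page 5, offset 1234
    final long packed = pack(pointer, 42);                 // partition 42
    System.out.println(partitionId(packed));               // prints 42
    System.out.println(recordPointer(packed) == pointer);  // prints true
  }
}
```

The right shift by 24 is just 51 minus 27: it moves the 13 page-number bits from the top of the pointer down so they sit immediately above the 27 offset bits, which is what lets page and offset share the lower 40 bits of the packed word.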

+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this + * optimization to future work as it will require more careful design to ensure that addresses are + * properly aligned (e.g. by padding records). + */ +final class PackedRecordPointer { + + static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + + /** + * The maximum partition identifier that can be encoded. Note that partition ids start from 0. + */ + static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 + + /** Bit mask for the lower 40 bits of a long. */ + private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; + + /** Bit mask for the upper 24 bits of a long */ + private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS; + + /** Bit mask for the lower 27 bits of a long. */ + private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Pack a record address and partition id into a single word. + * + * @param recordPointer a record pointer encoded by TaskMemoryManager. + * @param partitionId a shuffle partition id (maximum value of 2^24). + * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class. + */ + public static long packPointer(long recordPointer, int partitionId) { + assert (partitionId <= MAXIMUM_PARTITION_ID); + // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. + // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. + final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24; + final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS); + return (((long) partitionId) << 40) | compressedAddress; + } + + private long packedRecordPointer; + + public void set(long packedRecordPointer) { + this.packedRecordPointer = packedRecordPointer; + } + + public int getPartitionId() { + return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40); + } + + public long getRecordPointer() { + final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS; + return pageNumber | offsetInPage; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java new file mode 100644 index 0000000000000..7bac0dc0bbeb6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; + +import org.apache.spark.storage.TempShuffleBlockId; + +/** + * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + */ +final class SpillInfo { + final long[] partitionLengths; + final File file; + final TempShuffleBlockId blockId; + + public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java new file mode 100644 index 0000000000000..1d460432be9ff --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; + +import scala.Tuple2; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * An external sorter that is specialized for sort-based shuffle. + *

+ * Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * written to a single output file (or multiple files, if we've spilled). The format of the output + * files is the same as the format of the final output file written by + * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are + * written as a single serialized, compressed stream that can be read with a new decompression and + * deserialization stream. + *
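To make the "read with a new decompression and deserialization stream" point concrete, here is a minimal reader sketch. It is not Spark's actual read path, and the class and method names are invented for illustration; it only assumes that the per-partition lengths recorded in the index file are already available.

```
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

import com.google.common.io.ByteStreams;

import org.apache.spark.io.CompressionCodec;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.serializer.DeserializationStream;
import org.apache.spark.serializer.SerializerInstance;

// Illustrative only: each reduce partition is an independent compressed,
// serialized region of the data file, so it can be opened on its own.
final class PartitionRegionReader {
  static DeserializationStream openPartition(
      File dataFile,
      long[] partitionLengths,   // per-partition lengths from the index file
      int reducePartition,
      CompressionCodec codec,
      SerializerInstance serializer) throws IOException {
    long offset = 0;
    for (int p = 0; p < reducePartition; p++) {
      offset += partitionLengths[p];
    }
    final FileInputStream in = new FileInputStream(dataFile);
    ByteStreams.skipFully(in, offset);  // seek to the start of this partition's region
    final InputStream limited = new LimitedInputStream(in, partitionLengths[reducePartition]);
    // A fresh decompression stream and a fresh deserialization stream per partition;
    // the caller is responsible for closing the returned stream.
    return serializer.deserializeStream(codec.compressedInputStream(limited));
  }
}
```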

+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. + */ +final class UnsafeShuffleExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + + private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; + @VisibleForTesting + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; + + private final int initialSize; + private final int numPartitions; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + + private final LinkedList spills = new LinkedList(); + + // These variables are reset after spilling: + private UnsafeShuffleInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + public UnsafeShuffleExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf, + ShuffleWriteMetrics writeMetrics) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.initialSize = initialSize; + this.numPartitions = numPartitions; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + + this.writeMetrics = writeMetrics; + initializeForWriting(); + } + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + } + + /** + * Sorts the in-memory records and writes the sorted records to an on-disk file. + * This method does not free the sort data structures. + * + * @param isLastFile if true, this indicates that we're writing the final output file and that the + * bytes written should be counted towards shuffle spill metrics rather than + * shuffle write metrics. 
+ */ + private void writeSortedFile(boolean isLastFile) throws IOException { + + final ShuffleWriteMetrics writeMetricsToUse; + + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + // This call performs the actual sort. + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + sorter.getSortedIterator(); + + // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this + // after SPARK-5581 is fixed. + DiskBlockObjectWriter writer; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. This array does not need to be large enough to hold a single + // record; + final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + // Because this output will be read during shuffle, its compression codec must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more details. + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. 
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE; + + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + + int currentPartition = -1; + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + writer.commitAndClose(); + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + } + currentPartition = partition; + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = memoryManager.getPage(recordPointer); + final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); + int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + 4; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + PlatformDependent.copyMemory( + recordPage, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); + } + + if (writer != null) { + writer.commitAndClose(); + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + spills.add(spillInfo); + } + } + + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records + // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. + writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + } + } + + /** + * Sort and spill the current records in response to memory pressure. 
+ */ + @VisibleForTesting + void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); + + writeSortedFile(false); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeForWriting(); + } + + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + + private long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupAfterError() { + freeMemory(); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + if (sorter != null) { + shuffleMemoryManager.release(sorter.getMemoryUsage()); + sorter = null; + } + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. 
+ if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the shuffle sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + int partitionId) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + freeSpaceInCurrentPage -= 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + freeSpaceInCurrentPage -= lengthInBytes; + sorter.insertRecord(recordAddress, partitionId); + } + + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + * @throws IOException + */ + public SpillInfo[] closeAndGetSpills() throws IOException { + try { + if (sorter != null) { + // Do not count the final file towards the spill count. + writeSortedFile(true); + freeMemory(); + } + return spills.toArray(new SpillInfo[spills.size()]); + } catch (IOException e) { + cleanupAfterError(); + throw e; + } + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java new file mode 100644 index 0000000000000..5bab501da9364 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Comparator; + +import org.apache.spark.util.collection.Sorter; + +final class UnsafeShuffleInMemorySorter { + + private final Sorter sorter; + private static final class SortComparator implements Comparator { + @Override + public int compare(PackedRecordPointer left, PackedRecordPointer right) { + return left.getPartitionId() - right.getPartitionId(); + } + } + private static final SortComparator SORT_COMPARATOR = new SortComparator(); + + /** + * An array of record pointers and partition ids that have been encoded by + * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating + * records. + */ + private long[] pointerArray; + + /** + * The position in the pointer array where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeShuffleInMemorySorter(int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize]; + this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 1 < pointerArray.length; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + /** + * Inserts a record to be sorted. + * + * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to + * certain pointer compression techniques used by the sorter, the sort can + * only operate on pointers that point to locations in the first + * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page. + * @param partitionId the partition id, which must be less than or equal to + * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}. + */ + public void insertRecord(long recordPointer, int partitionId) { + if (!hasSpaceForAnotherRecord()) { + if (pointerArray.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Sort pointer array has reached maximum size"); + } else { + expandPointerArray(); + } + } + pointerArray[pointerArrayInsertPosition] = + PackedRecordPointer.packPointer(recordPointer, partitionId); + pointerArrayInsertPosition++; + } + + /** + * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. + */ + public static final class UnsafeShuffleSorterIterator { + + private final long[] pointerArray; + private final int numRecords; + final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); + private int position = 0; + + public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + this.numRecords = numRecords; + this.pointerArray = pointerArray; + } + + public boolean hasNext() { + return position < numRecords; + } + + public void loadNext() { + packedRecordPointer.set(pointerArray[position]); + position++; + } + } + + /** + * Return an iterator over record pointers in sorted order. 
+ */ + public UnsafeShuffleSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); + return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java new file mode 100644 index 0000000000000..a66d74ee44782 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.util.collection.SortDataFormat; + +final class UnsafeShuffleSortDataFormat extends SortDataFormat { + + public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + + private UnsafeShuffleSortDataFormat() { } + + @Override + public PackedRecordPointer getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public PackedRecordPointer newKey() { + return new PackedRecordPointer(); + } + + @Override + public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { + reuse.set(data[pos]); + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + final long temp = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = temp; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos] = src[srcPos]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos, dst, dstPos, length); + } + + @Override + public long[] allocate(int length) { + return new long[length]; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java new file mode 100644 index 0000000000000..764578b181422 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.*; +import java.nio.channels.FileChannel; +import java.util.Iterator; +import javax.annotation.Nullable; + +import scala.Option; +import scala.Product2; +import scala.collection.JavaConversions; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.*; +import org.apache.spark.annotation.Private; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +@Private +public class UnsafeShuffleWriter extends ShuffleWriter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); + + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + @VisibleForTesting + static final int INITIAL_SORT_BUFFER_SIZE = 4096; + + private final BlockManager blockManager; + private final IndexShuffleBlockResolver shuffleBlockResolver; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private final boolean transferToEnabled; + + private MapStatus mapStatus = null; + private UnsafeShuffleExternalSorter sorter = null; + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { + public MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } + } + + private MyByteArrayOutputStream serBuffer; + private SerializationStream serOutputStream; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. 
+ */ + private boolean stopping = false; + + public UnsafeShuffleWriter( + BlockManager blockManager, + IndexShuffleBlockResolver shuffleBlockResolver, + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + UnsafeShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf sparkConf) throws IOException { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + throw new IllegalArgumentException( + "UnsafeShuffleWriter can only be used for shuffles with at most " + + UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + } + this.blockManager = blockManager; + this.shuffleBlockResolver = shuffleBlockResolver; + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.mapId = mapId; + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); + this.partitioner = dep.partitioner(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.taskContext = taskContext; + this.sparkConf = sparkConf; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + open(); + } + + /** + * This convenience method should only be called in test code. + */ + @VisibleForTesting + public void write(Iterator> records) throws IOException { + write(JavaConversions.asScalaIterator(records)); + } + + @Override + public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we ecountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. + boolean success = false; + try { + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); + success = true; + } finally { + if (sorter != null) { + try { + sorter.cleanupAfterError(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. + if (success) { + throw e; + } else { + logger.error("In addition to a failure during writing, we failed during " + + "cleanup.", e); + } + } + } + } + } + + private void open() throws IOException { + assert (sorter == null); + sorter = new UnsafeShuffleExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + INITIAL_SORT_BUFFER_SIZE, + partitioner.numPartitions(), + sparkConf, + writeMetrics); + serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serOutputStream = serializer.serializeStream(serBuffer); + } + + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + serBuffer = null; + serOutputStream = null; + final SpillInfo[] spills = sorter.closeAndGetSpills(); + sorter = null; + final long[] partitionLengths; + try { + partitionLengths = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! 
spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } + + @VisibleForTesting + void insertRecordIntoSorter(Product2 record) throws IOException { + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serBuffer.reset(); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); + } + + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpills(SpillInfo[] spills) throws IOException { + final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); + final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); + final boolean fastMergeEnabled = + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + final boolean fastMergeIsSupported = + !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; + try { + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + Files.move(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. 
+ if (transferToEnabled) { + logger.debug("Using transferTo-based fast merge"); + partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + } else { + logger.debug("Using fileStream-based fast merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + } + } else { + logger.debug("Using slow merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incShuffleBytesWritten(outputFile.length()); + return partitionLengths; + } + } catch (IOException e) { + if (outputFile.exists() && !outputFile.delete()) { + logger.error("Unable to delete output file {}", outputFile.getPath()); + } + throw e; + } + } + + /** + * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, + * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in + * cases where the IO compression codec does not support concatenation of compressed data, or in + * cases where users have explicitly disabled use of {@code transferTo} in order to work around + * kernel bugs. + * + * @param spills the spills to merge. + * @param outputFile the file to write the merged data to. + * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. + * @return the partition lengths in the merged file. 
+ */ + private long[] mergeSpillsWithFileStream( + SpillInfo[] spills, + File outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + OutputStream mergedFileOutputStream = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new FileInputStream(spills[i].file); + } + for (int partition = 0; partition < numPartitions; partition++) { + final long initialFileLength = outputFile.length(); + mergedFileOutputStream = + new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + if (compressionCodec != null) { + mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + } + + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, mergedFileOutputStream); + } + } + mergedFileOutputStream.flush(); + mergedFileOutputStream.close(); + partitionLengths[partition] = (outputFile.length() - initialFileLength); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + Closeables.close(mergedFileOutputStream, threwException); + } + return partitionLengths; + } + + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. + * This is only safe when the IO compression codec and serializer support concatenation of + * serialized streams. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; + FileChannel mergedFileOutputChannel = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. 
+ mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + + long bytesWrittenToMergedFile = 0; + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + long bytesToTransfer = partitionLengthInSpill; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + while (bytesToTransfer > 0) { + final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannelPositions[i], + bytesToTransfer, + mergedFileOutputChannel); + spillInputChannelPositions[i] += actualBytesTransferred; + bytesToTransfer -= actualBytesTransferred; + } + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + bytesWrittenToMergedFile += partitionLengthInSpill; + partitionLengths[partition] += partitionLengthInSpill; + } + } + // Check the position after transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." + ); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); + Closeables.close(spillInputChannels[i], threwException); + } + Closeables.close(mergedFileOutputChannel, threwException); + } + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + try { + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return Option.apply(null); + } + } + } finally { + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupAfterError(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java new file mode 100644 index 0000000000000..dc2aa30466cc6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage; + +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; + +/** + * Intercepts write calls and tracks total time spent writing in order to update shuffle write + * metrics. Not thread safe. + */ +@Private +public final class TimeTrackingOutputStream extends OutputStream { + + private final ShuffleWriteMetrics writeMetrics; + private final OutputStream outputStream; + + public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + this.writeMetrics = writeMetrics; + this.outputStream = outputStream; + } + + @Override + public void write(int b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b, off, len); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void flush() throws IOException { + final long startTime = System.nanoTime(); + outputStream.flush(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void close() throws IOException { + final long startTime = System.nanoTime(); + outputStream.close(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java new file mode 100644 index 0000000000000..45b78829e4cf7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.annotation.Private; + +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +@Private +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java new file mode 100644 index 0000000000000..438742565c51d --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.base.Charsets; +import com.google.common.primitives.Longs; +import com.google.common.primitives.UnsignedBytes; + +import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.types.UTF8String; + +@Private +public class PrefixComparators { + private PrefixComparators() {} + + public static final StringPrefixComparator STRING = new StringPrefixComparator(); + public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); + public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + + public static final class StringPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + // TODO: can done more efficiently + byte[] a = Longs.toByteArray(aPrefix); + byte[] b = Longs.toByteArray(bPrefix); + for (int i = 0; i < 8; i++) { + int c = UnsignedBytes.compare(a[i], b[i]); + if (c != 0) return c; + } + return 0; + } + + public long computePrefix(byte[] bytes) { + if (bytes == null) { + return 0L; + } else { + byte[] padded = new byte[8]; + System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); + return Longs.fromByteArray(padded); + } + } + + public long computePrefix(String value) { + return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + } + + public long computePrefix(UTF8String value) { + return value == null ? 0L : computePrefix(value.getBytes()); + } + } + + /** + * Prefix comparator for all integral types (boolean, byte, short, int, long). + */ + public static final class IntegralPrefixComparator extends PrefixComparator { + @Override + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } + + public final long NULL_PREFIX = Long.MIN_VALUE; + } + + public static final class FloatPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + float a = Float.intBitsToFloat((int) aPrefix); + float b = Float.intBitsToFloat((int) bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(float value) { + return Float.floatToIntBits(value) & 0xffffffffL; + } + + public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + + public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java new file mode 100644 index 0000000000000..09e4258792204 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java new file mode 100644 index 0000000000000..0c4ebde407cfc --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 0000000000000..4d6731ee60af3 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.LinkedList; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * External sorter based on {@link UnsafeInMemorySorter}. + */ +public final class UnsafeExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + private static final int PAGE_SIZE = 1 << 27; // 128 megabytes + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; + + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. 
The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>(); + + // These variables are reset after spilling: + private UnsafeInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>(); + + public UnsafeExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + SparkConf conf) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + initializeForWriting(); + } + + // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = + new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize); + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + public void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size(), + spillWriters.size() > 1 ?
" times" : " time"); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + sorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + initializeForWriting(); + } + + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + + public long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + * + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory cannot be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + // TODO: merge these steps to first calculate total memory requirements for this insert, + // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the + // data page. + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} exceeds the free space in the current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap.
+ if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + + sorter.insertRecord(recordAddress, prefix); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); + int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); + if (spillWriters.isEmpty()) { + return inMemoryIterator; + } else { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + if (inMemoryIterator.hasNext()) { + spillMerger.addSpill(inMemoryIterator); + } + return spillMerger.getSortedIterator(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 0000000000000..fc34ad9cff369 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator<RecordPointerAndKeyPrefix> { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r1.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length + return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + } else { + return prefixComparisonResult; + } + } + } + + private final TaskMemoryManager memoryManager; + private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter; + private final Comparator<RecordPointerAndKeyPrefix> sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] pointerArray; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize * 2]; + this.memoryManager = memoryManager; + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } + + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pointerArrayInsertPosition / 2; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 2 < pointerArray.length; + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ?
(oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a 4-byte integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix a user-defined key prefix + */ + public void insertRecord(long recordPointer, long keyPrefix) { + if (!hasSpaceForAnotherRecord()) { + expandPointerArray(); + } + pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArrayInsertPosition++; + pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArrayInsertPosition++; + } + + private static final class SortedIterator extends UnsafeSorterIterator { + + private final TaskMemoryManager memoryManager; + private final int sortBufferInsertPosition; + private final long[] sortBuffer; + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + + SortedIterator( + TaskMemoryManager memoryManager, + int sortBufferInsertPosition, + long[] sortBuffer) { + this.memoryManager = memoryManager; + this.sortBufferInsertPosition = sortBufferInsertPosition; + this.sortBuffer = sortBuffer; + } + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length + recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 0000000000000..d09c728a7a638 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * Supports sorting an array of (record pointer, key prefix) pairs. + * Used in {@link UnsafeInMemorySorter}. + *
+ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); + } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 0000000000000..16ac2e8d821ba --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java new file mode 100644 index 0000000000000..8272c2a5be0d1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.PriorityQueue; + +final class UnsafeSorterSpillMerger { + + private final PriorityQueue<UnsafeSorterIterator> priorityQueue; + + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + final int numSpills) { + final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() { + + @Override + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator); + } + + public void addSpill(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + spillReader.loadNext(); + } + priorityQueue.add(spillReader); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return
spillReader.getKeyPrefix(); } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 0000000000000..29e9e0f30f934 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import com.google.common.io.ByteStreams; + +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description + * of the file format). + */ +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + + private InputStream in; + private DataInputStream din; + + // Variables that change with every record read: + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; + private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + assert (file.length() > 0); + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public void loadNext() throws IOException { + recordLength = din.readInt(); + keyPrefix = din.readLong(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + ByteStreams.readFully(in, arr, 0, recordLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { + in.close(); + in = null; + din = null; + } + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 0000000000000..71eed29563d4a --- /dev/null +++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Tuple2; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Spills a list of sorted records to disk. Spill files have the following format: + * + * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] + */ +final class UnsafeSorterSpillWriter { + + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private final int numRecordsToWrite; + private DiskBlockObjectWriter writer; + private int numRecordsSpilled = 0; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + // Based on DataOutputStream.writeLong. 
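As an illustrative aside before the two buffer helpers below: they hand-encode values in big-endian byte order, matching what java.io.DataOutputStream produces, and the resulting spill layout is the one documented above, [# of records (int)] followed by [len (int)][prefix (long)][data (bytes)] per record. A simplified, self-contained sketch of writing and re-reading that layout with DataOutputStream/DataInputStream (file and class names here are made up) is:

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

final class SpillFormatSketch {
  public static void main(String[] args) throws IOException {
    byte[] payload = {1, 2, 3, 4};
    try (DataOutputStream out = new DataOutputStream(new FileOutputStream("spill.sketch"))) {
      out.writeInt(1);              // number of records
      out.writeInt(payload.length); // record length
      out.writeLong(42L);           // key prefix
      out.write(payload);           // record bytes
    }
    try (DataInputStream in = new DataInputStream(new FileInputStream("spill.sketch"))) {
      int numRecords = in.readInt();
      for (int i = 0; i < numRecords; i++) {
        int length = in.readInt();
        long prefix = in.readLong();
        byte[] data = new byte[length];
        in.readFully(data);
      }
    }
  }
}

The actual writer streams records through DiskBlockObjectWriter via a reusable byte buffer instead of a DataOutputStream, and UnsafeSorterSpillReader above consumes the same layout, optionally wrapped in a compression stream by the BlockManager.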
+ private void writeLongToBuffer(long v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 56); + writeBuffer[offset + 1] = (byte)(v >>> 48); + writeBuffer[offset + 2] = (byte)(v >>> 40); + writeBuffer[offset + 3] = (byte)(v >>> 32); + writeBuffer[offset + 4] = (byte)(v >>> 24); + writeBuffer[offset + 5] = (byte)(v >>> 16); + writeBuffer[offset + 6] = (byte)(v >>> 8); + writeBuffer[offset + 7] = (byte)(v >>> 0); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); + } + + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + * @param keyPrefix a sort key prefix + */ + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } + writeIntToBuffer(recordLength, 0); + writeLongToBuffer(keyPrefix, 4); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } + writer.recordWritten(); + } + + public void close() throws IOException { + writer.commitAndClose(); + writer = null; + writeBuffer = null; + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties new file mode 100644 index 0000000000000..b146f8a784127 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -0,0 +1,12 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git 
a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 013db8df9b363..0b450dc76bc38 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -50,4 +50,9 @@ $(function() { $("span.additional-metric-title").click(function() { $(this).parent().find('input[type="checkbox"]').trigger('click'); }); + + // Trigger a double click on the span to show full job description. + $(".description-input").dblclick(function() { + $(this).removeClass("description-input").addClass("description-input-full"); + }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js index acf2d93b718b2..2d9262b972a59 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -20,7 +20,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -module.exports={graphlib:require("./lib/graphlib"),dagre:require("./lib/dagre"),intersect:require("./lib/intersect"),render:require("./lib/render"),util:require("./lib/util"),version:require("./lib/version")}},{"./lib/dagre":8,"./lib/graphlib":9,"./lib/intersect":10,"./lib/render":23,"./lib/util":25,"./lib/version":26}],2:[function(require,module,exports){var util=require("./util");module.exports={"default":normal,normal:normal,vee:vee,undirected:undirected};function normal(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function vee(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 L 4 5 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function undirected(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 5 L 10 5").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}},{"./util":25}],3:[function(require,module,exports){var _=require("./lodash"),addLabel=require("./label/add-label"),util=require("./util");module.exports=createClusters;function createClusters(selection,g){var clusters=g.nodes().filter(function(v){return util.isSubgraph(g,v)}),svgClusters=selection.selectAll("g.cluster").data(clusters,function(v){return v});var makeClusterIdentifier=function(v){return"cluster_"+v.replace(/^cluster/,"")};svgClusters.enter().append("g").attr("id",makeClusterIdentifier).attr("name",function(v){return g.node(v).label}).classed("cluster",true).style("opacity",0).append("rect");var sortedClusters=util.orderByRank(g,svgClusters.data());for(var 
diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
index acf2d93b718b2..2d9262b972a59 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
@@ -20,7 +20,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-[previous minified dagre-d3 bundle omitted]
+[updated minified dagre-d3 bundle omitted: dagre 0.4.4-pre, graphlib 1.0.1, and a bundled lodash build]
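dagre-d3 is the layout and rendering library the new DAG visualization builds on: a graph is described through its bundled graphlib API and then handed to a renderer bound to an SVG group. A minimal usage sketch of that API (generic dagre-d3 0.4.x usage, not code from this patch; the node names are made up):

    // Describe a two-node graph and render it into an existing <svg><g></g></svg>.
    var g = new dagreD3.graphlib.Graph().setGraph({});
    g.setNode("stage_0", { label: "Stage 0" });
    g.setNode("stage_1", { label: "Stage 1" });
    g.setEdge("stage_0", "stage_1", {});

    var render = new dagreD3.render();
    render(d3.select("svg g"), g);  // computes the layout and draws nodes and edges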
diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
index dbacbf19beee5..dde6069000bc4 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
@@ -100,7 +100,7 @@ sorttable = {
           this.removeChild(document.getElementById('sorttable_sortfwdind'));
           sortrevind = document.createElement('span');
           sortrevind.id = "sorttable_sortrevind";
-          sortrevind.innerHTML = stIsIE ? ' 5' : ' ▴';
+          sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾';
           this.appendChild(sortrevind);
           return;
         }
@@ -113,7 +113,7 @@ sorttable = {
           this.removeChild(document.getElementById('sorttable_sortrevind'));
           sortfwdind = document.createElement('span');
           sortfwdind.id = "sorttable_sortfwdind";
-          sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾';
+          sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴';
           this.appendChild(sortfwdind);
           return;
         }
@@ -134,7 +134,7 @@ sorttable = {
         this.className += ' sorttable_sorted';
         sortfwdind = document.createElement('span');
         sortfwdind.id = "sorttable_sortfwdind";
-        sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾';
+        sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴';
         this.appendChild(sortfwdind);
 
         // build an array to sort. This is a Schwartzian transform thing,
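The last context line above refers to a "Schwartzian transform": decorate each row with its precomputed sort key, sort by that key, then strip the decoration, so the potentially expensive key extraction runs once per row rather than once per comparison. A small self-contained sketch of the pattern (illustrative; this is not sorttable's actual implementation):

    // Decorate-sort-undecorate: compute each key once, sort the pairs, then unwrap.
    function sortByKey(items, keyFn) {
      return items
        .map(function (item) { return [keyFn(item), item]; })                        // decorate
        .sort(function (a, b) { return a[0] < b[0] ? -1 : (a[0] > b[0] ? 1 : 0); })  // sort on key
        .map(function (pair) { return pair[1]; });                                   // undecorate
    }

    // Example: sort row objects by a numeric duration parsed out of display text.
    var rows = [{ duration: "12 s" }, { duration: "3 s" }, { duration: "47 s" }];
    var byDuration = sortByKey(rows, function (row) { return parseInt(row.duration, 10); });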
' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); // build an array to sort. This is a Schwartzian transform thing, diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index 18c72694f3e2d..3b4ae2ed354b8 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -15,33 +15,26 @@ * limitations under the License. */ -#dag-viz-graph svg path { - stroke: #444; - stroke-width: 1.5px; +#dag-viz-graph a, #dag-viz-graph a:hover { + text-decoration: none; } -#dag-viz-graph svg g.cluster rect { - stroke-width: 1px; +#dag-viz-graph .label { + font-weight: normal; + text-shadow: none; } -#dag-viz-graph svg g.node circle { - fill: #444; +#dag-viz-graph svg path { + stroke: #444; + stroke-width: 1.5px; } -#dag-viz-graph svg g.node rect { - fill: #C3EBFF; - stroke: #3EC0FF; +#dag-viz-graph svg g.cluster rect { stroke-width: 1px; } -#dag-viz-graph svg g.node.cached circle { - fill: #444; -} - -#dag-viz-graph svg g.node.cached rect { - fill: #B3F5C5; - stroke: #56F578; - stroke-width: 1px; +#dag-viz-graph div#empty-dag-viz-message { + margin: 15px; } /* Job page specific styles */ @@ -57,12 +50,23 @@ stroke-width: 1px; } -#dag-viz-graph svg.job g.cluster[id*="stage"] rect { +#dag-viz-graph svg.job g.cluster.skipped rect { + fill: #D6D6D6; + stroke: #B7B7B7; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g.cluster.stage rect { fill: #FFFFFF; stroke: #FF99AC; stroke-width: 1px; } +#dag-viz-graph svg.job g.cluster.stage.skipped rect { + stroke: #ADADAD; + stroke-width: 1px; +} + #dag-viz-graph svg.job g#cross-stage-edges path { fill: none; } @@ -71,6 +75,20 @@ fill: #333; } +#dag-viz-graph svg.job g.cluster.skipped text { + fill: #666; +} + +#dag-viz-graph svg.job g.node circle { + fill: #444; +} + +#dag-viz-graph svg.job g.node.cached circle { + fill: #A3F545; + stroke: #52C366; + stroke-width: 2px; +} + /* Stage page specific styles */ #dag-viz-graph svg.stage g.cluster rect { @@ -79,7 +97,7 @@ stroke-width: 1px; } -#dag-viz-graph svg.stage g.cluster[id*="stage"] rect { +#dag-viz-graph svg.stage g.cluster.stage rect { fill: #FFFFFF; stroke: #FFA6B6; stroke-width: 1px; @@ -93,11 +111,14 @@ fill: #333; } -#dag-viz-graph a, #dag-viz-graph a:hover { - text-decoration: none; +#dag-viz-graph svg.stage g.node rect { + fill: #C3EBFF; + stroke: #3EC0FF; + stroke-width: 1px; } -#dag-viz-graph .label { - font-weight: normal; - text-shadow: none; +#dag-viz-graph svg.stage g.node.cached rect { + fill: #B3F5C5; + stroke: #52C366; + stroke-width: 2px; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index f7d0d3c61457c..9fa53baaf4212 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -57,9 +57,7 @@ var VizConstants = { stageSep: 40, graphPrefix: "graph_", nodePrefix: "node_", - stagePrefix: "stage_", - clusterPrefix: "cluster_", - stageClusterPrefix: "cluster_stage_" + clusterPrefix: "cluster_" }; var JobPageVizConstants = { @@ -86,7 +84,7 @@ function toggleDagViz(forJob) { $(arrowSelector).toggleClass('arrow-open'); var shouldShow = $(arrowSelector).hasClass("arrow-open"); if (shouldShow) { - var shouldRender = graphContainer().select("svg").empty(); + var shouldRender = 
graphContainer().select("*").empty(); if (shouldRender) { renderDagViz(forJob); } @@ -108,7 +106,7 @@ function toggleDagViz(forJob) { * Output DOM hierarchy: * div#dag-viz-graph > * svg > - * g#cluster_stage_[stageId] + * g.cluster_stage_[stageId] * * Note that the input metadata is populated by o.a.s.ui.UIUtils.showDagViz. * Any changes in the input format here must be reflected there. @@ -117,17 +115,23 @@ function renderDagViz(forJob) { // If there is not a dot file to render, fail fast and report error var jobOrStage = forJob ? "job" : "stage"; - if (metadataContainer().empty()) { - graphContainer() - .append("div") - .text("No visualization information available for this " + jobOrStage); + if (metadataContainer().empty() || + metadataContainer().selectAll("div").empty()) { + var message = + "No visualization information available for this " + jobOrStage + "!
" + + "If this is an old " + jobOrStage + ", its visualization metadata may have been " + + "cleaned up over time.
You may consider increasing the value of "; + if (forJob) { + message += "spark.ui.retainedJobs and spark.ui.retainedStages."; + } else { + message += "spark.ui.retainedStages"; + } + graphContainer().append("div").attr("id", "empty-dag-viz-message").html(message); return; } // Render - var svg = graphContainer() - .append("svg") - .attr("class", jobOrStage); + var svg = graphContainer().append("svg").attr("class", jobOrStage); if (forJob) { renderDagVizForJob(svg); } else { @@ -136,8 +140,9 @@ function renderDagViz(forJob) { // Find cached RDDs and mark them as such metadataContainer().selectAll(".cached-rdd").each(function(v) { - var nodeId = VizConstants.nodePrefix + d3.select(this).text(); - svg.selectAll("#" + nodeId).classed("cached", true); + var rddId = d3.select(this).text().trim(); + var nodeId = VizConstants.nodePrefix + rddId; + svg.selectAll("g." + nodeId).classed("cached", true); }); resizeSvg(svg); @@ -146,7 +151,7 @@ function renderDagViz(forJob) { /* Render the RDD DAG visualization on the stage page. */ function renderDagVizForStage(svgContainer) { var metadata = metadataContainer().select(".stage-metadata"); - var dot = metadata.select(".dot-file").text(); + var dot = metadata.select(".dot-file").text().trim(); var containerId = VizConstants.graphPrefix + metadata.attr("stage-id"); var container = svgContainer.append("g").attr("id", containerId); renderDot(dot, container, false); @@ -177,29 +182,35 @@ function renderDagVizForJob(svgContainer) { var dot = metadata.select(".dot-file").text(); var stageId = metadata.attr("stage-id"); var containerId = VizConstants.graphPrefix + stageId; - // Link each graph to the corresponding stage page (TODO: handle stage attempts) - var stageLink = "/stages/stage/?id=" + - stageId.replace(VizConstants.stagePrefix, "") + "&attempt=0&expandDagViz=true"; - var container = svgContainer - .append("a") - .attr("xlink:href", stageLink) - .append("g") - .attr("id", containerId); + var isSkipped = metadata.attr("skipped") == "true"; + var container; + if (isSkipped) { + container = svgContainer + .append("g") + .attr("id", containerId) + .attr("skipped", "true"); + } else { + // Link each graph to the corresponding stage page (TODO: handle stage attempts) + // Use the link from the stage table so it also works for the history server + var attemptId = 0 + var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) + .select("a.name-link") + .attr("href") + "&expandDagViz=true"; + container = svgContainer + .append("a") + .attr("xlink:href", stageLink) + .append("g") + .attr("id", containerId); + } // Now we need to shift the container for this stage so it doesn't overlap with // existing ones, taking into account the position and width of the last stage's // container. We do not need to do this for the first stage of this job. 
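The empty-visualization message above points users at the web UI's retention settings. A minimal sketch of raising those limits from application code follows; only the two `spark.ui.retained*` keys come from this patch, the values are arbitrary examples:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Retain metadata for more jobs and stages so DAG visualizations of older
// jobs are less likely to be cleaned up. Values below are illustrative only.
val conf = new SparkConf()
  .setAppName("dag-viz-retention-demo")
  .set("spark.ui.retainedJobs", "2000")
  .set("spark.ui.retainedStages", "2000")

val sc = new SparkContext(conf)
```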
if (i > 0) { - var existingStages = svgContainer - .selectAll("g.cluster") - .filter("[id*=\"" + VizConstants.stageClusterPrefix + "\"]"); + var existingStages = svgContainer.selectAll("g.cluster.stage") if (!existingStages.empty()) { var lastStage = d3.select(existingStages[0].pop()); - var lastStageId = lastStage.attr("id"); - var lastStageWidth = toFloat(svgContainer - .select("#" + lastStageId) - .select("rect") - .attr("width")); + var lastStageWidth = toFloat(lastStage.select("rect").attr("width")); var lastStagePosition = getAbsolutePosition(lastStage); var offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep; container.attr("transform", "translate(" + offset + ", 0)"); @@ -209,6 +220,12 @@ function renderDagVizForJob(svgContainer) { // Actually render the stage renderDot(dot, container, true); + // Mark elements as skipped if appropriate. Unfortunately we need to mark all + // elements instead of the parent container because of CSS override rules. + if (isSkipped) { + container.selectAll("g").classed("skipped", true); + } + // Round corners on rectangles container .selectAll("rect") @@ -219,7 +236,7 @@ function renderDagVizForJob(svgContainer) { // them separately later. Note that we cannot draw them now because we need to // put these edges in a separate container that is on top of all stage graphs. metadata.selectAll(".incoming-edge").each(function(v) { - var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4] + var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4] crossStageEdges.push(edge); }); }); @@ -238,6 +255,9 @@ function renderDot(dot, container, forJob) { var renderer = new dagreD3.render(); preprocessGraphLayout(g, forJob); renderer(container, g); + + // Find the stage cluster and mark it for styling and post-processing + container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true); } /* -------------------- * @@ -372,14 +392,14 @@ function getAbsolutePosition(d3selection) { function connectRDDs(fromRDDId, toRDDId, edgesContainer, svgContainer) { var fromNodeId = VizConstants.nodePrefix + fromRDDId; var toNodeId = VizConstants.nodePrefix + toRDDId; - var fromPos = getAbsolutePosition(svgContainer.select("#" + fromNodeId)); - var toPos = getAbsolutePosition(svgContainer.select("#" + toNodeId)); + var fromPos = getAbsolutePosition(svgContainer.select("g." + fromNodeId)); + var toPos = getAbsolutePosition(svgContainer.select("g." + toNodeId)); // On the job page, RDDs are rendered as dots (circles). When rendering the path, // we need to account for the radii of these circles. Otherwise the arrow heads // will bleed into the circle itself. var delta = toFloat(svgContainer - .select("g.node#" + toNodeId) + .select("g.node." 
+ toNodeId) .select("circle") .attr("r")); if (fromPos.x < toPos.x) { @@ -431,10 +451,35 @@ function addTooltipsForRDDs(svgContainer) { node.select("circle") .attr("data-toggle", "tooltip") .attr("data-placement", "bottom") - .attr("title", tooltipText) + .attr("title", tooltipText); } + // Link tooltips for all nodes that belong to the same RDD + node.on("mouseenter", function() { triggerTooltipForRDD(node, true); }); + node.on("mouseleave", function() { triggerTooltipForRDD(node, false); }); }); - $("[data-toggle=tooltip]").tooltip({container: "body"}); + + $("[data-toggle=tooltip]") + .filter("g.node circle") + .tooltip({ container: "body", trigger: "manual" }); +} + +/* + * (Job page only) Helper function to show or hide tooltips for all nodes + * in the graph that refer to the same RDD the specified node represents. + */ +function triggerTooltipForRDD(d3node, show) { + var classes = d3node.node().classList; + for (var i = 0; i < classes.length; i++) { + var clazz = classes[i]; + var isRDDClass = clazz.indexOf(VizConstants.nodePrefix) == 0; + if (isRDDClass) { + graphContainer().selectAll("g." + clazz).each(function() { + var circle = d3.select(this).select("circle").node(); + var showOrHide = show ? "show" : "hide"; + $(circle).tooltip(showOrHide); + }); + } + } } /* Helper function to convert attributes to numeric values. */ diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css index d1e6d462b836f..0f400461c5293 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -24,6 +24,65 @@ div#application-timeline, div#job-timeline { margin-top: 5px; } +#task-assignment-timeline div.legend-area { + width: 574px; +} + +#task-assignment-timeline .legend-area > svg { + width: 100%; + height: 55px; +} + +#task-assignment-timeline div.item.range { + padding: 0px; + height: 26px; + border-width: 0; +} + +.task-assignment-timeline-content { + width: 100%; +} + +.task-assignment-timeline-duration-bar { + width: 100%; + height: 26px; +} + +rect.scheduler-delay-proportion { + fill: #80B1D3; + stroke: #6B94B0; +} + +rect.deserialization-time-proportion { + fill: #FB8072; + stroke: #D26B5F; +} + +rect.shuffle-read-time-proportion { + fill: #FDB462; + stroke: #D39651; +} + +rect.executor-runtime-proportion { + fill: #B3DE69; + stroke: #95B957; +} + +rect.shuffle-write-time-proportion { + fill: #FFED6F; + stroke: #D5C65C; +} + +rect.serialization-time-proportion { + fill: #BC80BD; + stroke: #9D6B9E; +} + +rect.getting-result-time-proportion { + fill: #8DD3C7; + stroke: #75B0A6; +} + .vis.timeline { line-height: 14px; } @@ -178,6 +237,10 @@ tr.corresponding-item-hover > td, tr.corresponding-item-hover > th { display: none; } +#task-assignment-timeline.collapsed { + display: none; +} + .control-panel { margin-bottom: 5px; } @@ -186,7 +249,8 @@ tr.corresponding-item-hover > td, tr.corresponding-item-hover > th { margin: 0; } -span.expand-application-timeline, span.expand-job-timeline { +span.expand-application-timeline, span.expand-job-timeline, +span.expand-task-assignment-timeline { cursor: pointer; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 558beb8a5867f..ca74ef9d7e94e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ 
b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -46,7 +46,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var jobPagePath = $(getSelectorForJobEntry(this)).find("a").attr("href") + var jobPagePath = $(getSelectorForJobEntry(this)).find("a.name-link").attr("href") window.location.href = jobPagePath }); @@ -105,7 +105,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var stagePagePath = $(getSelectorForStageEntry(this)).find("a").attr("href") + var stagePagePath = $(getSelectorForStageEntry(this)).find("a.name-link").attr("href") window.location.href = stagePagePath }); @@ -133,6 +133,57 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }); } +function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { + var groups = new vis.DataSet(groupArray); + var items = new vis.DataSet(eventObjArray); + var container = $("#task-assignment-timeline")[0] + var options = { + groupOrder: function(a, b) { + return a.value - b.value + }, + editable: false, + align: 'left', + selectable: false, + showCurrentTime: false, + min: minLaunchTime, + max: maxFinishTime, + zoomable: false + }; + + var taskTimeline = new vis.Timeline(container) + taskTimeline.setOptions(options); + taskTimeline.setGroups(groups); + taskTimeline.setItems(items); + + // If a user zooms while a tooltip is displayed, the user may zoom such that the cursor is no + // longer over the task that the tooltip corresponds to. So, when a user zooms, we should hide + // any currently displayed tooltips. + var currentDisplayedTooltip = null; + $("#task-assignment-timeline").on({ + "mouseenter": function() { + currentDisplayedTooltip = this; + }, + "mouseleave": function() { + currentDisplayedTooltip = null; + } + }, ".task-assignment-timeline-content"); + taskTimeline.on("rangechange", function(prop) { + if (currentDisplayedTooltip !== null) { + $(currentDisplayedTooltip).tooltip("hide"); + } + }); + + setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); + + $("span.expand-task-assignment-timeline").click(function() { + $("#task-assignment-timeline").toggleClass("collapsed"); + + // Switch the class of the arrow from open to closed. 
+ $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); + $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); + }); +} + function setupExecutorEventAction() { $(".item.box.executor").each(function () { $(this).hover( @@ -147,7 +198,7 @@ function setupExecutorEventAction() { } function setupZoomable(id, timeline) { - $(id + '>input[type="checkbox"]').click(function() { + $(id + ' > input[type="checkbox"]').click(function() { if (this.checked) { timeline.setOptions({zoomable: true}); } else { @@ -155,7 +206,7 @@ function setupZoomable(id, timeline) { } }); - $(id + ">span").click(function() { + $(id + " > span").click(function() { $(this).parent().find('input:checkbox').trigger('click'); }); } diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index e7c1d475d4e52..b1cef47042247 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -135,6 +135,14 @@ pre { display: block; } +.description-input-full { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: normal; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 330df1d59a9b1..5a8d17bd99933 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -228,7 +228,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @tparam T result type */ class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T,T](initialValue, param, name) { + extends Accumulable[T, T](initialValue, param, name) { def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index af9765d313e9e..ceeb58075d345 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,8 +34,8 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. + // When spilling is enabled sorting will happen externally, but not necessarily with an + // ExternalSorter. 
private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") @@ -45,7 +45,7 @@ case class Aggregator[K, V, C] ( def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kv: Product2[K, V] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) @@ -76,7 +76,7 @@ case class Aggregator[K, V, C] ( : Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kc: Product2[K, C] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 66bda68088502..0c50b4002cf7b 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -91,7 +91,7 @@ private[spark] class ExecutorAllocationManager( // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.schedulerBacklogTimeout", "5s") + "spark.dynamicAllocation.schedulerBacklogTimeout", "1s") // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -99,7 +99,10 @@ private[spark] class ExecutorAllocationManager( // How long an executor must be idle for before it is removed (seconds) private val executorIdleTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.executorIdleTimeout", "600s") + "spark.dynamicAllocation.executorIdleTimeout", "60s") + + private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${Integer.MAX_VALUE}s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -150,6 +153,13 @@ private[spark] class ExecutorAllocationManager( // Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem. val executorAllocationManagerSource = new ExecutorAllocationManagerSource + // Whether we are still waiting for the initial set of executors to be allocated. + // While this is true, we will not cancel outstanding executor requests. This is + // set to false when: + // (1) a stage is submitted, or + // (2) an executor idle timeout has elapsed. + @volatile private var initializing: Boolean = true + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. 
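The ExecutorAllocationManager changes above lower the default scheduler-backlog and executor-idle timeouts and add `spark.dynamicAllocation.cachedExecutorIdleTimeout` for executors holding cached blocks. A minimal configuration sketch; only the keys come from this patch, the values are examples:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Dynamic allocation knobs touched by this change (values are illustrative).
val conf = new SparkConf()
  .setAppName("dynamic-allocation-demo")
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.dynamicAllocation.minExecutors", "2")
  .set("spark.dynamicAllocation.maxExecutors", "50")
  // Default drops from 5s to 1s in this patch; raise it to ramp up less eagerly.
  .set("spark.dynamicAllocation.schedulerBacklogTimeout", "5s")
  // Default drops from 600s to 60s in this patch.
  .set("spark.dynamicAllocation.executorIdleTimeout", "120s")
  // New in this patch: idle timeout used for executors with cached blocks.
  .set("spark.dynamicAllocation.cachedExecutorIdleTimeout", "30min")

val sc = new SparkContext(conf)
```

The new `initializing` flag keeps the manager from cancelling the initial executor request until either a stage is submitted or an idle timeout fires.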
@@ -240,6 +250,7 @@ private[spark] class ExecutorAllocationManager( removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime if (expired) { + initializing = false removeExecutor(executorId) } !expired @@ -261,13 +272,23 @@ private[spark] class ExecutorAllocationManager( private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { val maxNeeded = maxNumExecutorsNeeded - if (maxNeeded < numExecutorsTarget) { + if (initializing) { + // Do not change our target while we are still initializing, + // Otherwise the first job may have to ramp up unnecessarily + 0 + } else if (maxNeeded < numExecutorsTarget) { // The target number exceeds the number we actually need, so stop adding new // executors and inform the cluster manager to cancel the extra pending requests val oldNumExecutorsTarget = numExecutorsTarget numExecutorsTarget = math.max(maxNeeded, minNumExecutors) - client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 + + // If the new target has not changed, avoid sending a message to the cluster manager + if (numExecutorsTarget < oldNumExecutorsTarget) { + client.requestTotalExecutors(numExecutorsTarget) + logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") + } numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) @@ -292,9 +313,8 @@ private[spark] class ExecutorAllocationManager( private def addExecutors(maxNumExecutorsNeeded: Int): Int = { // Do not request more executors if it would put our target over the upper bound if (numExecutorsTarget >= maxNumExecutors) { - val numExecutorsPending = numExecutorsTarget - executorIds.size - logDebug(s"Not adding executors because there are already ${executorIds.size} registered " + - s"and ${numExecutorsPending} pending executor(s) (limit $maxNumExecutors)") + logDebug(s"Not adding executors because our current target total " + + s"is already $numExecutorsTarget (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } @@ -310,10 +330,19 @@ private[spark] class ExecutorAllocationManager( // Ensure that our target fits within configured bounds: numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) + val delta = numExecutorsTarget - oldNumExecutorsTarget + + // If our target has not changed, do not send a message + // to the cluster manager and reset our exponential growth + if (delta == 0) { + numExecutorsToAdd = 1 + return 0 + } + val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) if (addRequestAcknowledged) { - val delta = numExecutorsTarget - oldNumExecutorsTarget - logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + + val executorsString = "executor" + { if (delta > 1) "s" else "" } + logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + s" (new desired total will be $numExecutorsTarget)") numExecutorsToAdd = if (delta == numExecutorsToAdd) { numExecutorsToAdd * 2 @@ -420,7 +449,7 @@ private[spark] class ExecutorAllocationManager( * This resets all variables used for adding executors. 
*/ private def onSchedulerQueueEmpty(): Unit = synchronized { - logDebug(s"Clearing timer to add executors because there are no more pending tasks") + logDebug("Clearing timer to add executors because there are no more pending tasks") addTime = NOT_SET numExecutorsToAdd = 1 } @@ -433,9 +462,23 @@ private[spark] class ExecutorAllocationManager( private def onExecutorIdle(executorId: String): Unit = synchronized { if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + // Note that it is not necessary to query the executors since all the cached + // blocks we are concerned with are reported to the driver. Note that this + // does not include broadcast blocks. + val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val now = clock.getTimeMillis() + val timeout = { + if (hasCachedBlocks) { + // Use a different timeout if the executor has cached blocks. + now + cachedExecutorIdleTimeoutS * 1000 + } else { + now + executorIdleTimeoutS * 1000 + } + } + val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow + removeTimes(executorId) = realTimeout logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000 + s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)") } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") @@ -467,6 +510,7 @@ private[spark] class ExecutorAllocationManager( private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + initializing = false val stageId = stageSubmitted.stageInfo.stageId val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 91f9ef8ce7185..48792a958130c 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished - + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index f2b024ff6cb67..221b1dab43278 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -24,12 +24,12 @@ import scala.collection.mutable import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. It will also + * components to convey liveness or execution information for in-progress tasks. 
It will also * expire the hosts that have not heartbeated for more than spark.network.timeout. */ private[spark] case class Heartbeat( @@ -43,15 +43,25 @@ private[spark] case class Heartbeat( */ private[spark] case object TaskSchedulerIsSet -private[spark] case object ExpireDeadHosts - +private[spark] case object ExpireDeadHosts + +private case class ExecutorRegistered(executorId: String) + +private case class ExecutorRemoved(executorId: String) + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext) - extends ThreadSafeRpcEndpoint with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) + extends ThreadSafeRpcEndpoint with SparkListener with Logging { + + def this(sc: SparkContext) { + this(sc, new SystemClock) + } + + sc.addSparkListener(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv @@ -62,18 +72,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = + private val slaveTimeoutMs = sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") - private val executorTimeoutMs = + private val executorTimeoutMs = sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 - + // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" - private val timeoutIntervalMs = + private val timeoutIntervalMs = sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") - private val checkTimeoutIntervalMs = + private val checkTimeoutIntervalMs = sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 - + private var timeoutCheckingTask: ScheduledFuture[_] = null // "eventLoopThread" is used to run some pretty fast actions. 
The actions running in it should not @@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def onStart(): Unit = { timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ExpireDeadHosts)) + Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) } }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - override def receive: PartialFunction[Any, Unit] = { - case ExpireDeadHosts => - expireDeadHosts() + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + + // Messages sent and received locally + case ExecutorRegistered(executorId) => + executorLastSeen(executorId) = clock.getTimeMillis() + context.reply(true) + case ExecutorRemoved(executorId) => + executorLastSeen.remove(executorId) + context.reply(true) case TaskSchedulerIsSet => scheduler = sc.taskScheduler - } + context.reply(true) + case ExpireDeadHosts => + expireDeadHosts() + context.reply(true) - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // Messages received from executors case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - executorLastSeen(executorId) = System.currentTimeMillis() - eventLoopThread.submit(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - context.reply(response) - } - }) + if (executorLastSeen.contains(executorId)) { + executorLastSeen(executorId) = clock.getTimeMillis() + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) + } else { + // This may happen if we get an executor's in-flight heartbeat immediately + // after we just removed it. It's not really an error condition so we should + // not log warning here. Otherwise there may be a lot of noise especially if + // we explicitly remove executors (SPARK-4134). + logDebug(s"Received heartbeat from unknown executor $executorId") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } + /** + * If the heartbeat receiver is not stopped, notify it of executor registrations. + */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + } + + /** + * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't + * log superfluous errors. + * + * Note that we must do this after the executor is actually removed to guard against the + * following race condition: if we remove an executor's metadata from our data structure + * prematurely, we may get an in-flight heartbeat from the executor before the executor is + * actually removed, in which case we will still mark the executor as a dead host later + * and expire it with loud error messages. 
+ */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + } + private def expireDeadHosts(): Unit = { logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") - val now = System.currentTimeMillis() + val now = clock.getTimeMillis() for ((executorId, lastSeenMs) <- executorLastSeen) { if (now - lastSeenMs > executorTimeoutMs) { logWarning(s"Removing executor $executorId with no recent heartbeats: " + @@ -140,7 +189,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } } - + override def onStop(): Unit = { if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7e706bcc42f04..7cf7bc0dc6810 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -50,8 +50,8 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() - - // If we only stop sc, but the driver process still run as a services then we need to delete + + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs try { Utils.deleteRecursively(baseDir) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 419d093d55643..f0598816d6c07 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,14 +121,28 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + // scalastyle:off println + if (Utils.isInInterpreter) { + val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" + Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") + System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") + case None => + System.err.println(s"Spark was unable to load $replDefaultLogProps") + } + } else { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } + // scalastyle:on println } } Logging.initialized = true @@ -145,7 +159,7 @@ private object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. 
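The Logging change above makes the REPL load a quieter log4j profile and prints a hint about `sc.setLogLevel`. A small usage sketch, assuming an already-constructed SparkContext named `sc`:

```scala
// Adjust driver-side log verbosity at runtime, e.g. from the spark-shell.
sc.setLogLevel("INFO")   // other accepted levels include WARN, ERROR, DEBUG
```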
- val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 018422827e1c8..862ffe868f58f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -284,6 +284,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return a list of locations that each have fraction of map output greater than the specified + * threshold. + * + * @param shuffleId id of the shuffle + * @param reducerId id of the reduce task + * @param numReducers total number of reducers in the shuffle + * @param fractionThreshold fraction of total map output size that a location must have + * for it to be considered large. + * + * This method is not thread-safe. + */ + def getLocationsWithLargestOutputs( + shuffleId: Int, + reducerId: Int, + numReducers: Int, + fractionThreshold: Double) + : Option[Array[BlockManagerId]] = { + + if (mapStatuses.contains(shuffleId)) { + val statuses = mapStatuses(shuffleId) + if (statuses.nonEmpty) { + // HashMap to add up sizes of all blocks at the same location + val locs = new HashMap[BlockManagerId, Long] + var totalOutputSize = 0L + var mapIdx = 0 + while (mapIdx < statuses.length) { + val status = statuses(mapIdx) + val blockSize = status.getSizeForBlock(reducerId) + if (blockSize > 0) { + locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize + totalOutputSize += blockSize + } + mapIdx = mapIdx + 1 + } + val topLocs = locs.filter { case (loc, size) => + size.toDouble / totalOutputSize >= fractionThreshold + } + // Return if we have any locations which satisfy the required threshold + if (topLocs.nonEmpty) { + return Some(topLocs.map(_._1).toArray) + } + } + } + None + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index b8d244408bc5b..82889bcd30988 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -103,7 +103,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { */ class RangePartitioner[K : Ordering : ClassTag, V]( @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], + @transient rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -185,7 +185,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => + case r: RangePartitioner[_, _] => r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending case _ => false @@ -249,7 +249,7 @@ private[spark] object RangePartitioner { * @param 
sampleSizePerPartition max sample size per partition * @return (total number of items, an array of (partitionId, number of items, sample)) */ - def sketch[K:ClassTag]( + def sketch[K : ClassTag]( rdd: RDD[K], sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id @@ -272,7 +272,7 @@ private[spark] object RangePartitioner { * @param partitions number of partitions * @return selected bounds */ - def determineBounds[K:Ordering:ClassTag]( + def determineBounds[K : Ordering : ClassTag]( candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = { val ordering = implicitly[Ordering[K]] diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af0..32df42d57dbd6 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,7 +17,9 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileInputStream} +import java.security.{KeyStore, NoSuchAlgorithmException} +import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +50,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +66,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -94,7 +97,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +105,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. 
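The `supportedAlgorithms` value being introduced here filters the configured cipher suites down to those the running JVM's provider actually supports (the intersection appears just below). A self-contained sketch of the same idea, using the generic "TLS" protocol and placeholder cipher names:

```scala
import javax.net.ssl.SSLContext

// Ask the provider which cipher suites it supports for a protocol, then keep
// only the configured ones. Cipher names below are placeholders.
val context = SSLContext.getInstance("TLS")
context.init(null, null, null)  // keys/trust/rng do not change the supported set
val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet

val requested = Set("TLS_RSA_WITH_AES_128_CBC_SHA", "NOT_A_REAL_CIPHER")
(requested &~ providerAlgorithms).foreach(c => println(s"Discarding unsupported cipher $c"))
val usable = requested & providerAlgorithms
```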
*/ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 8aed1e20e0686..673ef49e7c1c5 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -192,7 +192,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" - private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 private var aclsOn = sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false)) @@ -365,10 +365,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) cookie } else { // user must have set spark.authenticate.secret config - sparkConf.getOption("spark.authenticate.secret") match { + // For Master/Worker, auth secret is in conf; for Executors, it is in env variable + sys.env.get(SecurityManager.ENV_AUTH_SECRET) + .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match { case Some(value) => value case None => throw new Exception("Error: a secret key must be specified via the " + - "spark.authenticate.secret config") + SecurityManager.SPARK_AUTH_SECRET_CONF + " config") } } sCookie @@ -449,3 +451,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) override def getSaslUser(appId: String): String = getSaslUser() override def getSecretKey(appId: String): String = getSecretKey() } + +private[spark] object SecurityManager { + + val SPARK_AUTH_CONF: String = "spark.authenticate" + val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret" + // This is used to set auth secret to an executor's env variable. 
It should have the same + // value as SPARK_AUTH_SECERET_CONF set in SparkConf + val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET" +} diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index cb2cae185256a..beb2e27254725 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -41,7 +41,7 @@ class SerializableWritable[T <: Writable](@transient var t: T) extends Serializa private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() val ow = new ObjectWritable() - ow.setConf(new Configuration()) + ow.setConf(new Configuration(false)) ow.readFields(in) t = ow.get().asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala deleted file mode 100644 index 54fc3a856adfa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SizeEstimator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.apache.spark.annotation.DeveloperApi - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -@DeveloperApi -object SizeEstimator { - /** - * :: DeveloperApi :: - * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate - * includes space taken up by objects referenced by the given object, their references, and so on - * and so forth. - * - * This is useful for determining the amount of heap space a broadcast variable will occupy on - * each executor or the amount of space each object will take when caching objects in - * deserialized form. This is not the same as the serialized size of the object, which will - * typically be much smaller. - */ - @DeveloperApi - def estimate(obj: AnyRef): Long = org.apache.spark.util.SizeEstimator.estimate(obj) -} diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a8fc90ad2050e..6cf36fbbd6254 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsBytes(key: String, defaultValue: String): Long = { Utils.byteStringAsBytes(get(key, defaultValue)) } - + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. 
If no * suffix is provided then Kibibytes are assumed. @@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsKb(key: String, defaultValue: String): Long = { Utils.byteStringAsKb(get(key, defaultValue)) } - + /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. @@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsMb(key: String, defaultValue: String): Long = { Utils.byteStringAsMb(get(key, defaultValue)) } - + /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. @@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsGb(key: String, defaultValue: String): Long = { Utils.byteStringAsGb(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -480,8 +480,8 @@ private[spark] object SparkConf extends Logging { "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'.") ) - - Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) } /** @@ -508,8 +508,8 @@ private[spark] object SparkConf extends Logging { "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${s.toDouble * 1000}k")), + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( @@ -557,7 +557,7 @@ private[spark] object SparkConf extends Logging { def isExecutorStartupConf(name: String): Boolean = { isAkkaConf(name) || name.startsWith("spark.akka") || - name.startsWith("spark.auth") || + (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || isSparkPortConf(name) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b59f562d05ead..bd1cc332a63e7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. 
+ * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId @@ -371,6 +379,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli throw new SparkException("An application name must be set in your configuration") } + // System property spark.yarn.app.id must be set if user code ran by AM on a YARN cluster + // yarn-standalone is deprecated, but still supported + if ((master == "yarn-cluster" || master == "yarn-standalone") && + !_conf.contains("spark.yarn.app.id")) { + throw new SparkException("Detected yarn-cluster mode, but isn't running on a cluster. " + + "Deployment to YARN is not supported directly by SparkContext. Please use spark-submit.") + } + if (_conf.getBoolean("spark.logConf", false)) { logInfo("Spark configuration:\n" + _conf.toDebugString) } @@ -381,7 +397,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) .toSeq.flatten @@ -430,7 +446,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _ui = if (conf.getBoolean("spark.ui.enabled", true)) { Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, - _env.securityManager,appName, startTime = startTime)) + _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI None @@ -482,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) - _heartbeatReceiver.send(TaskSchedulerIsSet) + _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor @@ -516,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _executorAllocationManager = if (dynamicAllocationEnabled) { assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") + "Dynamic allocation of executors is currently only supported in YARN and Mesos mode") Some(new ExecutorAllocationManager(this, listenerBus, _conf)) } else { None @@ -537,7 +553,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() - _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) @@ -670,7 +685,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * Note: Return statements are NOT allowed in the given body. 
*/ - private def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) + private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) // Methods for creating RDDs @@ -689,6 +704,78 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } + /** + * Creates a new RDD[Long] containing elements from `start` to `end`(exclusive), increased by + * `step` every element. + * + * @note if we need to cache this RDD, we should make sure each partition does not exceed limit. + * + * @param start the start value. + * @param end the end value. + * @param step the incremental step + * @param numSlices the partition number of the new RDD. + * @return + */ + def range( + start: Long, + end: Long, + step: Long = 1, + numSlices: Int = defaultParallelism): RDD[Long] = withScope { + assertNotStopped() + // when step is 0, range will run infinitely + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + + new Iterator[Long] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + ret + } + } + }) + } + /** Distribute a local Scala collection to form an RDD. * * This method is identical to `parallelize`. @@ -744,7 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -765,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -791,9 +878,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. 
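The new `SparkContext.range` method above produces an RDD[Long] whose partitions generate their elements lazily, with explicit guards against Long overflow. A brief usage sketch, assuming an existing SparkContext `sc`:

```scala
// Even numbers in [0, 1000), spread over 4 partitions.
val evens = sc.range(0L, 1000L, step = 2L, numSlices = 4)
println(evens.count())                  // 500
println(evens.take(3).mkString(", "))   // 0, 2, 4

// step must be non-zero; a descending range uses a negative step.
val countdown = sc.range(10L, 0L, step = -1L)
```

Because each partition computes its elements on demand, very large ranges can be described without materializing them on the driver.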
+ * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( @@ -837,7 +925,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], - conf=conf) + conf = conf) val data = br.map { case (k, v) => val bytes = v.getBytes assert(bytes.length == recordLength, "Byte array does not have correct length") @@ -894,7 +982,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. - val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) + val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) new HadoopRDD( this, @@ -1079,8 +1167,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]): RDD[(K, V)] = { withScope { assertNotStopped() - val kc = kcf() - val vc = vcf() + val kc = clean(kcf)() + val vc = clean(vcf)() val format = classOf[SequenceFileInputFormat[Writable, Writable]] val writables = hadoopFile(path, format, kc.writableClass(km).asInstanceOf[Class[Writable]], @@ -1187,7 +1275,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { - val param = new GrowableAccumulableParam[R,T] + val param = new GrowableAccumulableParam[R, T] val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1236,7 +1324,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val uri = new URI(path) val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString - case _ => path + case _ => path } val hadoopPath = new Path(schemeCorrectedPath) @@ -1275,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Return whether dynamically adjusting the amount of resources allocated to - * this application is supported. This is currently only available for YARN. + * this application is supported. This is currently only available for YARN + * and Mesos coarse-grained mode. 
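A quick illustration of the glob advice added in the two notes; the HDFS paths are invented for the example and `sc` is an existing context.

```scala
// Invented HDFS paths; the glob targets the files directly instead of having
// Spark list the parent directory.
val pages = sc.wholeTextFiles("hdfs://namenode:8020/corpus/2015/*", minPartitions = 8)
pages.map { case (file, content) => (file, content.length) }.take(3)

// Same advice for the binary variant:
val images = sc.binaryFiles("hdfs://namenode:8020/images/*.png")
images.count()
```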
*/ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) + private[spark] def supportDynamicAllocation: Boolean = { + (master.contains("yarn") + || master.contains("mesos") + || _conf.getBoolean("spark.dynamicAllocation.testing", false)) + } /** * :: DeveloperApi :: @@ -1296,7 +1388,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestTotalExecutors(numExecutors) @@ -1314,7 +1406,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1332,7 +1424,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { assert(supportDynamicAllocation, - "Killing executors is currently only supported in YARN mode") + "Killing executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) @@ -1804,7 +1896,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability - * @throws SparkException if checkSerializable is set but f is not + * @throws SparkException if checkSerializable is set but f is not * serializable */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { @@ -1817,6 +1909,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * be a HDFS path if running on a cluster. */ def setCheckpointDir(directory: String) { + + // If we are running on a cluster, log a warning if the directory is local. + // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from + // its own local file system, which is incorrect because the checkpoint files + // are actually on the executor machines. 
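For context, the developer-API calls whose assertion messages change in this hunk can be exercised roughly as below. The executor ids are placeholders, and both calls return false when the scheduler backend cannot honour them.

```scala
// Executor ids are placeholders; these only work on a cluster manager with a
// coarse-grained backend (now YARN or coarse-grained Mesos).
if (sc.requestExecutors(numAdditionalExecutors = 2)) {
  println("Cluster manager acknowledged the request for 2 extra executors")
}
sc.killExecutors(Seq("3", "7"))   // release executors that are no longer needed
```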
+ if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { + logWarning("Checkpoint directory must be non-local " + + "if Spark is running on a cluster: " + directory) + } + checkpointDir = Option(directory).map { dir => val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) @@ -1866,7 +1968,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli for (className <- listenerClassNames) { // Use reflection to find the right constructor val constructors = { - val listenerClass = Class.forName(className) + val listenerClass = Utils.classForName(className) listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] } val constructorTakingSparkConf = constructors.find { c => @@ -1911,7 +2013,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Note: this code assumes that the task scheduler has been initialized and has contacted // the cluster manager to get an application ID (in case the cluster manager provides one). listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId), - startTime, sparkUser, applicationAttemptId)) + startTime, sparkUser, applicationAttemptId, schedulerBackend.getDriverLogUrls)) } /** Post the application end event */ @@ -2401,7 +2503,7 @@ object SparkContext extends Logging { "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.") } val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { @@ -2413,7 +2515,7 @@ object SparkContext extends Logging { } val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -2426,8 +2528,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { - val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2439,7 +2540,7 @@ object SparkContext extends Logging { val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0c4d28f786edd..adfece4d6e7c0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,7 +20,8 @@ package org.apache.spark import java.io.File import java.net.Socket -import scala.collection.JavaConversions._ +import akka.actor.ActorSystem + import scala.collection.mutable import scala.util.Properties @@ -75,7 +76,8 @@ class SparkEnv ( val 
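The new warning targets setups like the following; a sketch of the intended usage on a cluster, with an invented HDFS URI.

```scala
// Invented HDFS URI; any shared filesystem visible to driver and executors works.
sc.setCheckpointDir("hdfs://namenode:8020/tmp/spark-checkpoints")

val squares = sc.parallelize(1 to 1000).map(i => i.toLong * i)
squares.checkpoint()   // materialized on the next action
squares.count()
```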
conf: SparkConf) extends Logging { // TODO Remove actorSystem - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") + val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -87,39 +89,42 @@ class SparkEnv ( private var driverTmpDirToDelete: Option[String] = None private[spark] def stop() { - isStopped = true - pythonWorkers.foreach { case(key, worker) => worker.stop() } - Option(httpFileServer).foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() - - // Note that blockTransferService is stopped by BlockManager since it is started by it. - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. - // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the - // current working dir in executor which we do not need to delete. - driverTmpDirToDelete match { - case Some(path) => { - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + Option(httpFileServer).foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. + driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } @@ -168,7 +173,7 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. 
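The `stop()` rewrite above is essentially an idempotence guard; stripped of the Spark specifics it is just the following pattern (a toy, not the actual class).

```scala
// Toy version of the guard: a second stop() call, e.g. user code plus a
// shutdown hook, becomes a no-op instead of tearing things down twice.
class StoppableService {
  private var stopped = false
  def stop(): Unit = synchronized {
    if (!stopped) {
      stopped = true
      // ... close sockets, stop thread pools, delete temp dirs ...
    }
  }
}
```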
*/ - @deprecated("Use SparkEnv.get instead", "1.2") + @deprecated("Use SparkEnv.get instead", "1.2.0") def getThreadLocal: SparkEnv = { env } @@ -256,7 +261,7 @@ object SparkEnv extends Logging { // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) + val cls = Utils.classForName(className) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { @@ -298,7 +303,7 @@ object SparkEnv extends Logging { } } - val mapOutputTracker = if (isDriver) { + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { new MapOutputTrackerWorker(conf) @@ -313,7 +318,8 @@ object SparkEnv extends Logging { // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", + "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) @@ -347,7 +353,7 @@ object SparkEnv extends Logging { val fileServerPort = conf.getInt("spark.fileserver.port", 0) val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) + conf.set("spark.fileserver.uri", server.serverUri) server } else { null @@ -378,7 +384,7 @@ object SparkEnv extends Logging { } val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { - new OutputCommitCoordinator(conf) + new OutputCommitCoordinator(conf, isDriver) } val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 2ec42d3aea169..f5dd36cbcfe6d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hadoop OutputFormat. 
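With the extra entry in `shortShuffleMgrNames`, the new shuffle implementation can be selected by its short name; a minimal configuration sketch (the app name is arbitrary):

```scala
import org.apache.spark.SparkConf

// "hash" and "sort" (the default) keep working; the new short name resolves to
// org.apache.spark.shuffle.unsafe.UnsafeShuffleManager.
val conf = new SparkConf()
  .setAppName("tungsten-shuffle-demo")
  .set("spark.shuffle.manager", "tungsten-sort")
```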
@@ -42,7 +43,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) with Serializable { private val now = new Date() - private val conf = new SerializableWritable(jobConf) + private val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 @@ -50,8 +51,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private var jID: SerializableWritable[JobID] = null private var taID: SerializableWritable[TaskAttemptID] = null - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null + @transient private var writer: RecordWriter[AnyRef, AnyRef] = null + @transient private var format: OutputFormat[AnyRef, AnyRef] = null @transient private var committer: OutputCommitter = null @transient private var jobContext: JobContext = null @transient private var taskContext: TaskAttemptContext = null @@ -114,10 +115,10 @@ class SparkHadoopWriter(@transient jobConf: JobConf) // ********* Private Functions ********* - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { + private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { if (format == null) { format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] + .asInstanceOf[OutputFormat[AnyRef, AnyRef]] } format } @@ -138,7 +139,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + taskContext = newTaskAttemptContext(conf.value, taID.value) } taskContext } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 398ca41e16151..a1ebbecf93b7b 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -51,7 +51,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() val files1 = for (name <- classNames) yield { - createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) } val files2 = for ((childName, baseName) <- classNamesWithBase) yield { createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) @@ -105,23 +105,18 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private class JavaSourceFromString(val name: String, val code: String) + private[spark] class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + /** Creates a compiled class with the source file. Class file will be placed in destDir. 
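The switch from `SerializableWritable(jobConf)` to `SerializableJobConf` relies on the usual trick for Hadoop configs: they are `Writable` but not `Serializable`, so a wrapper serializes them through their own `write`/`readFields`. The class below is a sketch of that idea, not the actual `org.apache.spark.util.SerializableJobConf`, which may differ in detail.

```scala
import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.mapred.JobConf

// Sketch only: JobConf is Writable, not Serializable, so the wrapper ships it
// through its own write/readFields during Java serialization.
class SerializableJobConfSketch(@transient var value: JobConf) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    value.write(out)
  }

  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    value = new JobConf(false)
    value.readFields(in)
  }
}
```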
*/ def createCompiledClass( className: String, destDir: File, - toStringValue: String = "", - baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { + sourceFile: JavaSourceFromString, + classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") - val sourceFile = new JavaSourceFromString(className, - "public class " + className + extendsText + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. @@ -144,4 +139,18 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } + + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") + createCompiledClass(className, destDir, sourceFile, classpathUrls) + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 61af867b11b9c..a650df605b92e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) */ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index db4e996feb31c..ed312770ee131 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] @@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. 
- * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 8bf0627fc420d..c95615a5a9307 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @deprecated("Use partitions() instead.", "1.1.0") def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - + /** Set of partitions in this RDD. */ def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + /** The partitioner of this RDD. */ + def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context @@ -96,7 +99,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -386,9 +389,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. 
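The reworded `fold` contract can be seen directly in a shell session (assuming a live `sc`): addition satisfies it, subtraction does not.

```scala
// Addition is associative and commutative, so the distributed fold matches a
// local fold over the same numbers ...
val nums = sc.parallelize(1 to 100, numSlices = 4)
nums.fold(0)(_ + _)    // 5050, same as (1 to 100).sum

// ... but subtraction is neither, so each partition's partial result and the
// order in which partials are combined change the answer.
nums.fold(0)(_ - _)    // generally != (1 to 100).fold(0)(_ - _)
```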
*/ def fold(zeroValue: T)(f: JFunction2[T, T, T]): T = rdd.fold(zeroValue)(f) @@ -485,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } - def takeSample(withReplacement: Boolean, num: Int): JList[T] = + def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index c9181a29d4756..b959b683d1674 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,8 +19,8 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -61,7 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { + conf: Broadcast[SerializableConfiguration]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 7409dc2d866f6..dc9f62f39e6d5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -36,7 +36,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal @@ -47,6 +47,7 @@ private[spark] class PythonRDD( pythonIncludes: JList[String], preservePartitoning: Boolean, pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { @@ -210,6 +211,8 @@ private[spark] class PythonRDD( val dataOut = new DataOutputStream(stream) // Partition index dataOut.writeInt(split.index) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Python includes (*.zip and *.egg files) @@ -442,7 +445,7 @@ private[spark] object PythonRDD extends Logging { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val confBroadcasted = sc.sc.broadcast(new 
SerializableConfiguration(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -468,7 +471,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -494,7 +497,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -537,7 +540,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -563,7 +566,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -720,7 +723,7 @@ private[spark] object PythonRDD extends Logging { val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] - converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec) } /** @@ -794,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) - /** + /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded actor anyway. - */ + */ @transient var socket: Socket = _ def openSocket(): Socket = synchronized { @@ -840,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * An Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. 
*/ +// scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { /** @@ -881,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } +// scalastyle:on no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index efb6b93cfc35d..90dacaeb93429 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -50,8 +50,15 @@ private[spark] object PythonUtils { /** * Convert list of T into seq of T (for calling API with varargs) */ - def toSeq[T](cols: JList[T]): Seq[T] = { - cols.toList.toSeq + def toSeq[T](vs: JList[T]): Seq[T] = { + vs.toList.toSeq + } + + /** + * Convert list of T into array of T (for calling API with array) + */ + def toArray[T](vs: JList[T]): Array[T] = { + vs.toArray().asInstanceOf[Array[T]] } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 3a2c94bd9d875..b7e72d4d0ed0b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetSocketAddress, ServerSocket} +import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -29,7 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} /** * Netty-based backend server that is used to communicate between R and Java. 
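Several of the R-related changes in this area follow the same pattern: bind helper sockets to the loopback interface on an OS-assigned port, then read the bound port back and hand it to the peer process. A standalone sketch of that pattern (the timeout value is illustrative):

```scala
import java.net.{InetAddress, ServerSocket}

// Bind to loopback only, on an ephemeral port, then publish the bound port to
// the peer process (here it is just printed).
val server = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val boundPort = server.getLocalPort
server.setSoTimeout(10000)
println(s"R side should connect to 127.0.0.1:$boundPort")
// ... accept(), exchange data, then close() ...
server.close()
```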
@@ -41,14 +41,15 @@ private[spark] class RBackend { private[this] var bossGroup: EventLoopGroup = null def init(): Int = { - bossGroup = new NioEventLoopGroup(2) + val conf = new SparkConf() + bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) val workerGroup = bossGroup val handler = new RBackendHandler(this) - + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(classOf[NioServerSocketChannel]) - + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { def initChannel(ch: SocketChannel): Unit = { ch.pipeline() @@ -65,7 +66,7 @@ private[spark] class RBackend { } }) - channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() } @@ -94,14 +95,16 @@ private[spark] class RBackend { private[spark] object RBackend extends Logging { def main(args: Array[String]): Unit = { if (args.length < 1) { + // scalastyle:off println System.err.println("Usage: RBackend ") + // scalastyle:on println System.exit(-1) } val sparkRBackend = new RBackend() try { // bind to random port val boundPort = sparkRBackend.init() - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // tell the R process via temporary file diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 0075d963711f1..9658e9a696ffa 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -26,6 +26,7 @@ import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging import org.apache.spark.api.r.SerDe._ +import org.apache.spark.util.Utils /** * Handler for RBackend @@ -77,7 +78,7 @@ private[r] class RBackendHandler(server: RBackend) val reply = bos.toByteArray ctx.write(reply) } - + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { ctx.flush() } @@ -98,7 +99,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - Class.forName(objId) + Utils.classForName(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) @@ -124,7 +125,7 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args:_*) + val ret = methods.head.invoke(obj, args : _*) // Write status bit writeInt(dos, 0) @@ -135,7 +136,7 @@ private[r] class RBackendHandler(server: RBackend) matchMethod(numArgs, args, x.getParameterTypes) }.head - val obj = ctor.newInstance(args:_*) + val obj = ctor.newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 6fea5e1144f2f..23a470d6afcae 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io._ -import java.net.ServerSocket +import java.net.{InetAddress, ServerSocket} import java.util.{Map => JMap} import 
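The handler's dispatch logic, reduced to a standalone example: resolve a class by name, pick a method by name and arity, and invoke it with the deserialized arguments. The class and method below are arbitrary stand-ins; the real code also prefers `Utils.classForName` so the proper class loader is used.

```scala
// Stand-in class and method; the handler resolves user-supplied names the same way.
val cls = Class.forName("java.lang.Integer")
val method = cls.getMethods
  .filter(_.getName == "parseInt")
  .find(_.getParameterTypes.length == 1)
  .getOrElse(throw new Exception("No matched method found for parseInt"))
val result = method.invoke(null, "42")   // static method, so the receiver is null
println(result)                          // 42
```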
scala.collection.JavaConversions._ @@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { protected var dataStream: DataInputStream = _ @@ -55,12 +54,12 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( val parentIterator = firstParent[T].iterator(partition, context) // we expect two connections - val serverSocket = new ServerSocket(0, 2) + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // The stdout/stderr is shared by multiple tasks, because we use one daemon // to launch child process as worker. - val errThread = RRDD.createRWorker(rLibDir, listenPort) + val errThread = RRDD.createRWorker(listenPort) // We use two sockets to separate input and output, then it's easy to manage // the lifecycle of them to avoid deadlock. @@ -161,7 +160,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { // write string(for StringRRDD) + // scalastyle:off println printOut.println(elem) + // scalastyle:on println } } @@ -233,11 +234,10 @@ private class PairwiseRRDD[T: ClassTag]( hashFunc: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, (Int, Array[Byte])]( parent, numPartitions, hashFunc, deserializer, - SerializationFormats.BYTE, packageNames, rLibDir, + SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): (Int, Array[Byte]) = { @@ -264,10 +264,9 @@ private class RRDD[T: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, Array[Byte]]( - parent, -1, func, deserializer, serializer, packageNames, rLibDir, + parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): Array[Byte] = { @@ -291,10 +290,9 @@ private class StringRRDD[T: ClassTag]( func: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, String]( - parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): String = { @@ -309,7 +307,7 @@ private class StringRRDD[T: ClassTag]( } private object SpecialLengths { - val TIMING_DATA = -1 + val TIMING_DATA = -1 } private[r] class BufferedStreamThread( @@ -355,7 +353,6 @@ private[r] object RRDD { val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) - .setJars(jars) // Override `master` if we have a user-specified value if (master != "") { @@ -373,7 +370,11 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) } - new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(sparkConf) + jars.foreach { jar => + jsc.addJar(jar) + } + jsc } /** @@ -387,9 +388,10 @@ private[r] object RRDD { thread } - private def createRProcess(rLibDir: String, port: Int, script: String): 
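Two knobs touched in this hunk, shown as ordinary settings from the JVM side (all values and jar paths are invented): the R binary used for worker processes is now configurable, and jars are attached with `addJar` instead of `SparkConf.setJars`.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

// Illustrative values only.
val sparkConf = new SparkConf()
  .setAppName("sparkr-backend-demo")
  .set("spark.sparkr.r.command", "/usr/local/bin/Rscript")   // worker R binary

// Jars are now added one by one on the context instead of via setJars:
val jsc = new JavaSparkContext(sparkConf)
Seq("/opt/jobs/deps.jar", "/opt/jobs/udf.jar").foreach(jsc.addJar)
```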
BufferedStreamThread = { - val rCommand = "Rscript" + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. @@ -408,15 +410,15 @@ private[r] object RRDD { /** * ProcessBuilder used to launch worker R processes. */ - def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + def createRWorker(port: Int): BufferedStreamThread = { val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) if (!Utils.isWindows && useDaemon) { synchronized { if (daemonChannel == null) { // we expect one connections - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + errThread = createRProcess(daemonPort, "daemon.R") // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() @@ -438,7 +440,7 @@ private[r] object RRDD { errThread } } else { - createRProcess(rLibDir, port, "worker.R") + createRProcess(port, "worker.R") } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala new file mode 100644 index 0000000000000..d53abd3408c55 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.File + +import org.apache.spark.{SparkEnv, SparkException} + +private[spark] object RUtils { + /** + * Get the SparkR package path in the local spark distribution. + */ + def localSparkRPackagePath: Option[String] = { + val sparkHome = sys.env.get("SPARK_HOME") + sparkHome.map( + Seq(_, "R", "lib").mkString(File.separator) + ) + } + + /** + * Get the SparkR package path in various deployment modes. + * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` + * and environment variable `SPARK_HOME` are set. 
+ */ + def sparkRPackagePath(isDriver: Boolean): String = { + val (master, deployMode) = + if (isDriver) { + (sys.props("spark.master"), sys.props("spark.submit.deployMode")) + } else { + val sparkConf = SparkEnv.get.conf + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + } + + val isYarnCluster = master.contains("yarn") && deployMode == "cluster" + val isYarnClient = master.contains("yarn") && deployMode == "client" + + // In YARN mode, the SparkR package is distributed as an archive symbolically + // linked to the "sparkr" file in the current directory. Note that this does not apply + // to the driver in client mode because it is run outside of the cluster. + if (isYarnCluster || (isYarnClient && !isDriver)) { + new File("sparkr").getAbsolutePath + } else { + // Otherwise, assume the package is local + // TODO: support this for Mesos + localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 371dfe454d1a2..56adc857d4ce0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} -import java.sql.{Date, Time} +import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConversions._ @@ -107,9 +107,12 @@ private[spark] object SerDe { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Time = { - val t = in.readDouble() - new Time((t * 1000L).toLong) + def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t } def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { @@ -157,9 +160,11 @@ private[spark] object SerDe { val keysLen = readInt(in) val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - val valuesType = readObjectType(in) val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) mapAsJavaMap(keys.zip(values).toMap) } else { new java.util.HashMap[Object, Object]() @@ -225,6 +230,9 @@ private[spark] object SerDe { case "java.sql.Time" => writeType(dos, "time") writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Timestamp]) case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) @@ -287,6 +295,9 @@ private[spark] object SerDe { out.writeDouble(value.getTime.toDouble / 1000.0) } + def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 685313ac009ba..fac6666bb3410 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -22,6 +22,7 @@ import 
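The new `readTime`/`writeTime` pair amounts to a round trip between fractional epoch seconds and `java.sql.Timestamp` that keeps sub-second precision in the nanos field instead of truncating to whole seconds. The same arithmetic in isolation:

```scala
import java.sql.Timestamp

// R sends times as seconds since the epoch with a fractional part.
def secondsToTimestamp(seconds: Double): Timestamp = {
  val sec = math.floor(seconds).toLong
  val t = new Timestamp(sec * 1000L)
  t.setNanos(((seconds - sec) * 1e9).toInt)
  t
}

def timestampToSeconds(t: Timestamp): Double =
  (t.getTime / 1000).toDouble + t.getNanos.toDouble / 1e9

val t = secondsToTimestamp(1433865536.123456)
timestampToSeconds(t)   // ~1.433865536123456e9, sub-second part preserved
```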
java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -42,7 +43,7 @@ private[spark] class BroadcastManager( conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject broadcastFactory.initialize(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4457c75e8b0fc..b69af639f7862 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -125,7 +125,7 @@ private[broadcast] object HttpBroadcast extends Logging { securityManager = securityMgr if (isDriver) { createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) + conf.set("spark.httpBroadcast.uri", serverUri) } serverUri = conf.get("spark.httpBroadcast.uri") cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) @@ -187,7 +187,7 @@ private[broadcast] object HttpBroadcast extends Logging { } private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name var uc: URLConnection = null diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 848b62f9de71b..f03875a3e8c89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -18,17 +18,17 @@ package org.apache.spark.deploy import scala.collection.mutable.HashSet -import scala.concurrent._ +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import scala.util.{Failure, Success} -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} /** * Proxy that relays messages to the driver. @@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} * We currently don't support retry if submission fails. In HA mode, client will submit request to * all masters and see which one could handle it. 
*/ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - private val masterActors = driverArgs.masters.map { m => - context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) - } - private val lostMasters = new HashSet[Address] - private var activeMasterActor: ActorSelection = null - - val timeout = RpcUtils.askTimeout(conf) - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - +private class ClientEndpoint( + override val rpcEnv: RpcEnv, + driverArgs: ClientArguments, + masterEndpoints: Seq[RpcEndpointRef], + conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging { + + // A scheduled executor used to send messages at the specified time. + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message") + // Used to provide the implicit parameter of `Future` methods. + private val forwardMessageExecutionContext = + ExecutionContext.fromExecutor(forwardMessageThread, + t => t match { + case ie: InterruptedException => // Exit normally + case e: Throwable => + logError(e.getMessage, e) + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + }) + + private val lostMasters = new HashSet[RpcAddress] + private var activeMasterEndpoint: RpcEndpointRef = null + + override def onStart(): Unit = { driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -82,44 +92,52 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.cores, driverArgs.supervise, command) - - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestSubmitDriver(driverDescription) - } + ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestKillDriver(driverId) - } + ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + } + } + + /** + * Send the message to master and forward the reply to self asynchronously. + */ + private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + for (masterEndpoint <- masterEndpoints) { + masterEndpoint.ask[T](message).onComplete { + case Success(v) => self.send(v) + case Failure(e) => + logWarning(s"Error sending messages to master $masterEndpoint", e) + }(forwardMessageExecutionContext) } } /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println("... waiting before polling master for driver state") + // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread + // is fine. + logInfo("... waiting before polling master for driver state") Thread.sleep(5000) - println("... polling master for driver state") - val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + logInfo("... 
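The endpoint above replaces actor sends with a small ask-and-forward idiom: each `ask` runs its callback on a private single-thread executor and the result is re-delivered to `self`. A reduced sketch, with `sendToSelf` standing in for `self.send` and the RPC types elided:

```scala
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

// A private single-thread pool so callbacks never run on the RPC threads.
val forwardPool = Executors.newSingleThreadScheduledExecutor()
implicit val forwardEc: ExecutionContext = ExecutionContext.fromExecutor(forwardPool)

// `sendToSelf` stands in for self.send; `ask` is any Future-returning RPC call.
def askAndForward[T](ask: => Future[T])(sendToSelf: Any => Unit): Unit =
  ask.onComplete {
    case Success(v) => sendToSelf(v)
    case Failure(e) => println(s"Error sending message to master: ${e.getMessage}")
  }
```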
polling master for driver state") + val statusResponse = + activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => - println(s"ERROR: Cluster master did not recognize $driverId") + logError(s"ERROR: Cluster master did not recognize $driverId") System.exit(-1) case true => - println(s"State of $driverId is ${statusResponse.state.get}") + logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - println(s"Driver running on $hostPort ($id)") + logInfo(s"Driver running on $hostPort ($id)") case _ => } // Exception, if present statusResponse.exception.map { e => - println(s"Exception from cluster was: $e") + logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) } @@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { - case SubmitDriverResponse(success, driverId, message) => - println(message) + case SubmitDriverResponse(master, success, driverId, message) => + logInfo(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId.get) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } - case KillDriverResponse(driverId, success, message) => - println(message) + case KillDriverResponse(master, driverId, success, message) => + logInfo(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } + } - case DisassociatedEvent(_, remoteAddress, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") - lostMasters += remoteAddress - // Note that this heuristic does not account for the fact that a Master can recover within - // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This - // is not currently a concern, however, because this client does not retry submissions. - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. 
+ if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) } + } + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") - lostMasters += remoteAddress - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + logError(s"Error connecting to master ($remoteAddress).") + logError(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterEndpoints.size) { + logError("No master is available, exiting.") + System.exit(-1) } + } + } + + override def onError(cause: Throwable): Unit = { + logError(s"Error processing messages, exiting.") + cause.printStackTrace() + System.exit(-1) + } + + override def onStop(): Unit = { + forwardMessageThread.shutdownNow() } } @@ -179,10 +209,12 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) */ object Client { def main(args: Array[String]) { + // scalastyle:off println if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a future version of Spark") println("Use ./bin/spark-submit with \"--master spark://host:port\"") } + // scalastyle:on println val conf = new SparkConf() val driverArgs = new ClientArguments(args) @@ -194,15 +226,13 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val rpcEnv = + RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - for (m <- driverArgs.masters) { - Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) - } - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). + map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 316e2d59f01b8..72cc330a398da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) { cmd = "launch" if (!ClientArguments.isValidJarUrl(_jarUrl)) { + // scalastyle:off println println(s"Jar url '${_jarUrl}' is not in valid format.") println(s"Must be a jar file path in URL format " + "(e.g. 
hdfs://host:port/XX.jar, file:///XX.jar)") + // scalastyle:on println printUsageAndExit(-1) } @@ -110,14 +112,16 @@ private[deploy] class ClientArguments(args: Array[String]) { | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin + // scalastyle:off println System.err.println(usage) + // scalastyle:on println System.exit(exitCode) } } private[deploy] object ClientArguments { val DEFAULT_CORES = 1 - val DEFAULT_MEMORY = 512 // MB + val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB val DEFAULT_SUPERVISE = false def isValidJarUrl(s: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 9db6fd1ac4dbe..12727de9b4cf3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable -/** Contains messages sent between Scheduler actor nodes. */ +/** Contains messages sent between Scheduler endpoint nodes. */ private[deploy] object DeployMessages { // Worker to Master @@ -37,6 +38,7 @@ private[deploy] object DeployMessages { id: String, host: String, port: Int, + worker: RpcEndpointRef, cores: Int, memory: Int, webUiPort: Int, @@ -63,11 +65,11 @@ private[deploy] object DeployMessages { case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], driverIds: Seq[String]) - case class Heartbeat(workerId: String) extends DeployMessage + case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage @@ -92,13 +94,13 @@ private[deploy] object DeployMessages { // Worker internal - case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders case object ReregisterWithMaster // used when a worker attempts to reconnect to a master // AppClient to Master - case class RegisterApplication(appDescription: ApplicationDescription) + case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) extends DeployMessage case class UnregisterApplication(appId: String) @@ -107,7 +109,7 @@ private[deploy] object DeployMessages { // Master to AppClient - case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage + case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -123,12 +125,14 @@ private[deploy] object DeployMessages { case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage - case class SubmitDriverResponse(success: 
Boolean, driverId: Option[String], message: String) + case class SubmitDriverResponse( + master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String) extends DeployMessage case class RequestKillDriver(driverId: String) extends DeployMessage - case class KillDriverResponse(driverId: String, success: Boolean, message: String) + case class KillDriverResponse( + master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage case class RequestDriverStatus(driverId: String) extends DeployMessage @@ -142,7 +146,7 @@ private[deploy] object DeployMessages { // Master to Worker & AppClient - case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String) // MasterWebUI To Master diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c048b78910f38..b4edb6109e839 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -65,7 +65,7 @@ private object FaultToleranceTest extends App with Logging { private val workers = ListBuffer[TestWorkerInfo]() private var sc: SparkContext = _ - private val zk = SparkCuratorUtil.newClient(conf) + private val zk = SparkCuratorUtil.newClient(conf) private var numPassed = 0 private var numFailed = 0 diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 2954f932b4f41..ccffb36652988 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -76,12 +76,13 @@ private[deploy] object JsonProtocol { } def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 860e1a24901b6..53356addf6edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master @@ -41,8 +40,10 @@ class LocalSparkCluster( extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + 
private val masterRpcEnvs = ArrayBuffer[RpcEnv]() + private val workerRpcEnvs = ArrayBuffer[RpcEnv]() + // exposed for testing + var masterWebUIPort = -1 def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") @@ -53,16 +54,17 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) - masterActorSystems += masterSystem - val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort + val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) + masterWebUIPort = webUiPort + masterRpcEnvs += rpcEnv + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum), _conf) - workerActorSystems += workerSystem + workerRpcEnvs += workerEnv } masters @@ -73,11 +75,11 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. - workerActorSystems.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.shutdown()) // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) + masterRpcEnvs.foreach(_.shutdown()) // masterActorSystems.foreach(_.awaitTermination()) - masterActorSystems.clear() - workerActorSystems.clear() + masterRpcEnvs.clear() + workerRpcEnvs.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 53e18c4bcec23..c2ed43a5397d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,9 +18,11 @@ package org.apache.spark.deploy import java.net.URI +import java.io.File import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ +import scala.util.Try import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} @@ -81,16 +83,13 @@ object PythonRunner { throw new IllegalArgumentException("Launching Python applications through " + s"spark-submit is currently only supported for local files: $path") } - val windows = Utils.isWindows || testWindows - var formattedPath = if (windows) Utils.formatWindowsPath(path) else path - - // Strip the URI scheme from the path - formattedPath = - new URI(formattedPath).getScheme match { - case null => formattedPath - case Utils.windowsDrive(d) if windows => formattedPath - case _ => new URI(formattedPath).getPath - } + // get path when scheme is file. 
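The `formatPath` rewrite in the PythonRunner hunk below boils down to a scheme check on the primary resource. Here is a standalone sketch of that branching using only JDK classes; the helper name, wrapper object, and sample paths are illustrative and not part of the patch.

```scala
import java.io.File
import java.net.URI

import scala.util.Try

object LocalPathSketch {
  /** Sketch of the scheme handling used by PythonRunner.formatPath below. */
  def toLocalForm(path: String): Option[String] = {
    // Fall back to a file: URI when the raw string is not a valid URI (e.g. a Windows path).
    val uri = Try(new URI(path)).getOrElse(new File(path).toURI)
    uri.getScheme match {
      case null             => Some(path)        // "src/app.py" is kept as given
      case "file" | "local" => Some(uri.getPath) // "file:///tmp/app.py" -> "/tmp/app.py"
      case _                => None              // "hdfs://nn/app.py" is rejected as non-local
    }
  }
}
```

The patch itself keeps a plain `String` (returning `null`) so the existing malformed-path guard still fires, and it separately strips the leading slash from Windows drive paths such as `/C:/...`.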
+ val uri = Try(new URI(path)).getOrElse(new File(path).toURI) + var formattedPath = uri.getScheme match { + case null => path + case "file" | "local" => uri.getPath + case _ => null + } // Guard against malformed paths potentially throwing NPE if (formattedPath == null) { @@ -99,7 +98,9 @@ object PythonRunner { // In Windows, the drive should not be prefixed with "/" // For instance, python does not understand "/C:/path/to/sheep.py" - formattedPath = if (windows) formattedPath.stripPrefix("/") else formattedPath + if (Utils.isWindows && formattedPath.matches("/[a-zA-Z]:/.*")) { + formattedPath = formattedPath.stripPrefix("/") + } formattedPath } diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index e99779f299785..c0cab22fa8252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path -import org.apache.spark.api.r.RBackend +import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.util.RedirectThread /** @@ -71,9 +71,10 @@ object RRunner { val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) - val sparkHome = System.getenv("SPARK_HOME") + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir) env.put("R_PROFILE_USER", - Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() @@ -85,7 +86,9 @@ object RRunner { } System.exit(returnCode) } else { + // scalastyle:off println System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 7fa75ac8c2b54..9f94118829ff1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -178,7 +178,7 @@ class SparkHadoopUtil extends Logging { private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") statisticsDataClass.getDeclaredMethod(methodName) } @@ -334,6 +334,19 @@ class SparkHadoopUtil extends Logging { * Stop the thread that does the delegation token updates. */ private[spark] def stopExecutorDelegationTokenRenewer() {} + + /** + * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. + * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. 
+ */ + private[spark] def getConfBypassingFSCache( + hadoopConf: Configuration, + scheme: String): Configuration = { + val newConf = new Configuration(hadoopConf) + val confKey = s"fs.${scheme}.impl.disable.cache" + newConf.setBoolean(confKey, true) + newConf + } } object SparkHadoopUtil { @@ -343,7 +356,7 @@ object SparkHadoopUtil { System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") .newInstance() .asInstanceOf[SparkHadoopUtil] } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329fa06ba8ba5..036cb6e054791 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -35,7 +35,9 @@ import org.apache.ivy.core.resolve.ResolveOptions import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher -import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} +import org.apache.ivy.plugins.repository.file.FileRepository +import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} +import org.apache.spark.api.r.RUtils import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -78,17 +80,19 @@ object SparkSubmit { private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // scalastyle:off println // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(1) + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() + exitFn(1) } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to @@ -99,13 +103,16 @@ object SparkSubmit { /_/ """.format(SPARK_VERSION)) printStream.println("Type --help for more information.") - exitFn() + exitFn(0) } + // scalastyle:on println def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { + // scalastyle:off println printStream.println(appArgs) + // scalastyle:on println } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs) @@ -159,8 +166,10 @@ object SparkSubmit { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { + // scalastyle:off println printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - exitFn() + // scalastyle:on println + exitFn(1) } else { throw e } @@ -177,7 +186,9 @@ object SparkSubmit { // to use the legacy gateway if the master endpoint turns out to be not a REST server. 
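The `getConfBypassingFSCache` helper added to `SparkHadoopUtil` above relies on Hadoop's per-scheme cache switch. A minimal illustration of the effect is sketched below; the object name and NameNode address are made up, and the property key follows the `fs.<scheme>.impl.disable.cache` pattern used by the helper.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object FreshFileSystemSketch {
  def main(args: Array[String]): Unit = {
    val hadoopConf = new Configuration()
    // Copy first, as getConfBypassingFSCache does, so the caller's conf is untouched.
    val freshConf = new Configuration(hadoopConf)
    freshConf.setBoolean("fs.hdfs.impl.disable.cache", true)
    // With caching disabled for the scheme, this returns a brand-new client rather than
    // the process-wide cached instance, so no stale delegation token is reused.
    val fs = FileSystem.get(new Path("hdfs://namenode:8020/").toUri, freshConf)
    fs.close()
  }
}
```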
if (args.isStandaloneCluster && args.useRest) { try { + // scalastyle:off println printStream.println("Running Spark using the REST application submission protocol.") + // scalastyle:on println doRunMain() } catch { // Fail over to use the legacy submission gateway @@ -253,6 +264,12 @@ object SparkSubmit { } } + // Update args.deployMode if it is null. It will be passed down as a Spark property later. + (args.deployMode, deployMode) match { + case (null, CLIENT) => args.deployMode = "client" + case (null, CLUSTER) => args.deployMode = "cluster" + case _ => + } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER @@ -324,53 +341,35 @@ object SparkSubmit { // Usage: PythonAppRunner
<main python file> <extra python files>
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs - args.files = mergeFileLists(args.files, args.primaryResource) + if (clusterManager != YARN) { + // The YARN backend distributes the primary file differently, so don't merge it. + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + if (clusterManager != YARN) { + // The YARN backend handles python files differently, so don't merge the lists. + args.files = mergeFileLists(args.files, args.pyFiles) } - args.files = mergeFileLists(args.files, args.pyFiles) if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } } - // In yarn mode for a python app, add pyspark archives to files + // In YARN mode for an R app, add the SparkR package archive to archives // that can be distributed with the job - if (args.isPython && clusterManager == YARN) { - var pyArchives: String = null - val pyArchivesEnvOpt = sys.env.get("PYSPARK_ARCHIVES_PATH") - if (pyArchivesEnvOpt.isDefined) { - pyArchives = pyArchivesEnvOpt.get - } else { - if (!sys.env.contains("SPARK_HOME")) { - printErrorAndExit("SPARK_HOME does not exist for python application in yarn mode.") - } - val pythonPath = new ArrayBuffer[String] - for (sparkHome <- sys.env.get("SPARK_HOME")) { - val pyLibPath = Seq(sparkHome, "python", "lib").mkString(File.separator) - val pyArchivesFile = new File(pyLibPath, "pyspark.zip") - if (!pyArchivesFile.exists()) { - printErrorAndExit("pyspark.zip does not exist for python application in yarn mode.") - } - val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") - if (!py4jFile.exists()) { - printErrorAndExit("py4j-0.8.2.1-src.zip does not exist for python application " + - "in yarn mode.") - } - pythonPath += pyArchivesFile.getAbsolutePath() - pythonPath += py4jFile.getAbsolutePath() - } - pyArchives = pythonPath.mkString(",") + if (args.isR && clusterManager == YARN) { + val rPackagePath = RUtils.localSparkRPackagePath + if (rPackagePath.isEmpty) { + printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + } + val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } + val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) - pyArchives = pyArchives.split(",").map { localPath=> - val localURI = Utils.resolveURI(localPath) - if (localURI.getScheme != "local") { - args.files = mergeFileLists(args.files, localURI.toString) - new Path(localPath).getName - } else { - localURI.getPath - } - }.mkString(File.pathSeparator) - sysProps("spark.submit.pyArchives") = pyArchives + // Assigns a symbol link name "sparkr" to the shipped package. 
+ args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") } // If we're running a R app, set the main class to our specific R runner @@ -386,19 +385,10 @@ object SparkSubmit { } } - if (isYarnCluster) { - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) - } - + if (isYarnCluster && args.isR) { // In yarn-cluster mode for a R app, add primary resource to files // that can be distributed with the job - if (args.isR) { - args.files = mergeFileLists(args.files, args.primaryResource) - } + args.files = mergeFileLists(args.files, args.primaryResource) } // Special flag to avoid deprecation warnings at the client @@ -410,6 +400,8 @@ object SparkSubmit { // All cluster managers OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.submit.deployMode"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), @@ -425,9 +417,10 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), - OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), @@ -440,13 +433,11 @@ object SparkSubmit { OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), - - // Yarn client or cluster - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options - OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), @@ -516,17 +507,18 @@ object SparkSubmit { } } + // Let YARN know it's a pyspark app, so it distributes needed libraries. 
+ if (clusterManager == YARN && args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.isPython) { - val mainPyFile = new Path(args.primaryResource).getName - childArgs += ("--primary-py-file", mainPyFile) + childArgs += ("--primary-py-file", args.primaryResource) if (args.pyFiles != null) { - // These files will be distributed to each machine's working directory, so strip the - // path prefix - val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") - childArgs += ("--py-files", pyFilesNames) + childArgs += ("--py-files", args.pyFiles) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") } else if (args.isR) { @@ -601,6 +593,7 @@ object SparkSubmit { sysProps: Map[String, String], childMainClass: String, verbose: Boolean): Unit = { + // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -608,6 +601,7 @@ object SparkSubmit { printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } + // scalastyle:on println val loader = if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { @@ -630,13 +624,15 @@ object SparkSubmit { var mainClass: Class[_] = null try { - mainClass = Class.forName(childMainClass, true, loader) + mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { + // scalastyle:off println printStream.println(s"Failed to load main class $childMainClass.") printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -700,7 +696,7 @@ object SparkSubmit { /** * Return whether the given main class represents a sql shell. */ - private def isSqlShell(mainClass: String): Boolean = { + private[deploy] def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } @@ -753,7 +749,9 @@ private[spark] object SparkSubmitUtils { * @param artifactId the artifactId of the coordinate * @param version the version of the coordinate */ - private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { + override def toString: String = s"$groupId:$artifactId:$version" + } /** * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided @@ -776,6 +774,16 @@ private[spark] object SparkSubmitUtils { } } + /** Path of the local Maven cache. 
*/ + private[spark] def m2Path: File = { + if (Utils.isTesting) { + // test builds delete the maven cache, and this can cause flakiness + new File("dummy", ".m2" + File.separator + "repository") + } else { + new File(System.getProperty("user.home"), ".m2" + File.separator + "repository") + } + } + /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories @@ -787,20 +795,36 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + // scalastyle:off println + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) - val m2Path = ".m2" + File.separator + "repository" + File.separator - localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) + localM2.setRoot(m2Path.toURI.toString) localM2.setUsepoms(true) localM2.setName("local-m2-cache") cr.add(localM2) - val localIvy = new IBiblioResolver - localIvy.setRoot(new File(ivySettings.getDefaultIvyUserDir, - "local" + File.separator).toURI.toString) + val localIvy = new FileSystemResolver + val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local") + localIvy.setLocal(true) + localIvy.setRepository(new FileRepository(localIvyRoot)) val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", "[artifact](-[classifier]).[ext]").mkString(File.separator) - localIvy.setPattern(ivyPattern) + localIvy.addIvyPattern(localIvyRoot.getAbsolutePath + File.separator + ivyPattern) localIvy.setName("local-ivy-cache") cr.add(localIvy) @@ -817,20 +841,6 @@ private[spark] object SparkSubmitUtils { sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -860,22 +870,20 @@ private[spark] object SparkSubmitUtils { val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) dd.addDependencyConfiguration(ivyConfName, ivyConfName) + // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println md.addDependency(dd) } } - + /** Add exclusion rules for dependencies already included in the spark-assembly */ def addExclusionRules( ivySettings: IvySettings, ivyConfName: String, md: DefaultModuleDescriptor): Unit = { // Add scala exclusion rule - val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*") - val scalaDependencyExcludeRule = - new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null) - 
scalaDependencyExcludeRule.addConfiguration(ivyConfName) - md.addExcludeRule(scalaDependencyExcludeRule) + md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and // other spark-streaming utility components. Underscore is there to differentiate between @@ -884,13 +892,8 @@ private[spark] object SparkSubmitUtils { "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") components.foreach { comp => - val sparkArtifacts = - new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*") - val sparkDependencyExcludeRule = - new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null) - sparkDependencyExcludeRule.addConfiguration(ivyConfName) - - md.addExcludeRule(sparkDependencyExcludeRule) + md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings, + ivyConfName)) } } @@ -903,6 +906,7 @@ private[spark] object SparkSubmitUtils { * @param coordinates Comma-delimited string of maven coordinates * @param remoteRepos Comma-delimited string of remote repositories other than maven central * @param ivyPath The path to the local ivy repository + * @param exclusions Exclusions to apply when resolving transitive dependencies * @return The comma-delimited path to the jars of the given maven artifacts including their * transitive dependencies */ @@ -910,76 +914,107 @@ private[spark] object SparkSubmitUtils { coordinates: String, remoteRepos: Option[String], ivyPath: Option[String], + exclusions: Seq[String] = Nil, isTest: Boolean = false): String = { if (coordinates == null || coordinates.trim.isEmpty) { "" } else { val sysOut = System.out - // To prevent ivy from logging to system out - System.setOut(printStream) - val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings - // Directories for caching downloads through ivy and storing the jars when maven coordinates - // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") + try { + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + // scalastyle:off println + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, 
ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) } else { - ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") + resolveOptions.setDownload(true) } - printStream.println( - s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") - printStream.println(s"The jars for the packages stored in: $packagesDirectory") - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos, ivySettings) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) - - val ivy = Ivy.newInstance(ivySettings) - // Set resolve options to download transitive dependencies as well - val resolveOptions = new ResolveOptions - resolveOptions.setTransitive(true) - val retrieveOptions = new RetrieveOptions - // Turn downloading and logging off for testing - if (isTest) { - resolveOptions.setDownload(false) - resolveOptions.setLog(LogOptions.LOG_QUIET) - retrieveOptions.setLog(LogOptions.LOG_QUIET) - } else { - resolveOptions.setDownload(true) - } - - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - md.setDefaultConf(ivyConfName) - // Add exclusion rules for Spark and Scala Library - addExclusionRules(ivySettings, ivyConfName, md) - // add all supplied maven artifacts as dependencies - addDependenciesToIvy(md, artifacts, ivyConfName) + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. 
In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + + md.setDefaultConf(ivyConfName) + + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) + + exclusions.foreach { e => + md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName)) + } - // resolve dependencies - val rr: ResolveReport = ivy.resolve(md, resolveOptions) - if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } finally { + System.setOut(sysOut) } - // retrieve all resolved dependencies - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision].[ext]", - retrieveOptions.setConfs(Array(ivyConfName))) - System.setOut(sysOut) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } } + + private def createExclusion( + coords: String, + ivySettings: IvySettings, + ivyConfName: String): ExcludeRule = { + val c = extractMavenCoordinates(coords)(0) + val id = new ArtifactId(new ModuleId(c.groupId, c.artifactId), "*", "*", "*") + val rule = new DefaultExcludeRule(id, ivySettings.getMatcher("glob"), null) + rule.addConfiguration(ivyConfName) + rule + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c0e4c771908b3..b3710073e330c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,12 +17,15 @@ package org.apache.spark.deploy +import java.io.{ByteArrayOutputStream, PrintStream} +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.io.Source import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser @@ -76,6 +79,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() + // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => @@ -83,6 +87,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } + // scalastyle:on println defaultProperties } @@ -159,6 +164,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) @@ -169,6 +175,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull + principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -410,6 +418,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case VERSION => SparkSubmit.printVersionAndExit() + case USAGE_ERROR => + printUsageAndExit(1) + case _ => throw new IllegalArgumentException(s"Unexpected argument '$opt'.") } @@ -443,15 +454,20 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { + // scalastyle:off println val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) } - outStream.println( + val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] - |Usage: spark-submit --status [submission ID] --master [spark://...] - | + |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) + outStream.println(command) + + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB + outStream.println( + s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -477,7 +493,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M). + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: ${mem_mb}M). | --driver-java-options Extra Java options to pass to the driver. | --driver-library-path Extra library path entries to pass to the driver. | --driver-class-path Extra class path entries to pass to the driver. Note that @@ -523,6 +539,66 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | delegation tokens periodically. 
""".stripMargin ) - SparkSubmit.exitFn() + + if (SparkSubmit.isSqlShell(mainClass)) { + outStream.println("CLI options:") + outStream.println(getSqlShellOptions()) + } + // scalastyle:on println + + SparkSubmit.exitFn(exitCode) } + + /** + * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter + * the results to remove unwanted lines. + * + * Since the CLI will call `System.exit()`, we install a security manager to prevent that call + * from working, and restore the original one afterwards. + */ + private def getSqlShellOptions(): String = { + val currentOut = System.out + val currentErr = System.err + val currentSm = System.getSecurityManager() + try { + val out = new ByteArrayOutputStream() + val stream = new PrintStream(out) + System.setOut(stream) + System.setErr(stream) + + val sm = new SecurityManager() { + override def checkExit(status: Int): Unit = { + throw new SecurityException() + } + + override def checkPermission(perm: java.security.Permission): Unit = {} + } + System.setSecurityManager(sm) + + try { + Utils.classForName(mainClass).getMethod("main", classOf[Array[String]]) + .invoke(null, Array(HELP)) + } catch { + case e: InvocationTargetException => + // Ignore SecurityException, since we throw it above. + if (!e.getCause().isInstanceOf[SecurityException]) { + throw e + } + } + + stream.flush() + + // Get the output and discard any unnecessary lines from it. + Source.fromString(new String(out.toByteArray())).getLines + .filter { line => + !line.startsWith("log4j") && !line.startsWith("usage") + } + .mkString("\n") + } finally { + System.setSecurityManager(currentSm) + System.setOut(currentOut) + System.setErr(currentErr) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 43c8a934c311a..79b251e7e62fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. 
*/ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) + private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) - private val REGISTRATION_TIMEOUT = 20.seconds + private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var masterAddress: Address = null - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null private var appId: String = null - private var registered = false - private var activeMasterUrl: String = null + @volatile private var registered = false + + private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint + with Logging { + + private var master: Option[RpcEndpointRef] = None + // To avoid calling listener.disconnected() multiple times + private var alreadyDisconnected = false + @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times + @volatile private var registerMasterFutures: Array[JFuture[_]] = null + @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) - private class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + // A scheduled executor for scheduling the registration actions + private val registrationRetryThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + override def onStart(): Unit = { try { - registerWithMaster() + registerWithMaster(1) } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } - def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterApplication(appDescription) + /** + * Register with all masters asynchronously and returns an array `Future`s for cancellation. 
+ */ + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + for (masterAddress <- masterRpcAddresses) yield { + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = try { + if (registered) { + return + } + logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + val masterRef = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterRef.send(RegisterApplication(appDescription, self)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + }) } } - def registerWithMaster() { - tryRegisterAllMasters() - import context.dispatcher - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + /** + * Register with all masters asynchronously. It will call `registerWithMaster` every + * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times. + * Once we connect to a master successfully, all scheduling work and Futures will be cancelled. + * + * nthRetry means this is the nth attempt to register with master. + */ + private def registerWithMaster(nthRetry: Int) { + registerMasterFutures = tryRegisterAllMasters() + registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { Utils.tryOrExit { - retries += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() + } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - tryRegisterAllMasters() + registerMasterFutures.foreach(_.cancel(true)) + registerWithMaster(nthRetry + 1) } } } - } + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) } - def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. - activeMasterUrl = url - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => logWarning(s"Drop $message because has not yet connected to master") + } } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = { + masterRpcAddresses.contains(remoteAddress) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredApplication(appId_, masterUrl) => + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId_, masterRef) => + // FIXME How to handle the following cases? + // 1. A master receives multiple registrations and sends back multiple + // RegisteredApplications due to an unstable network. + // 2. Receive multiple RegisteredApplication from different masters because the master is + // changing. 
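The `registerWithMaster` logic above swaps Akka's scheduler for a plain `ScheduledExecutorService` with a bounded number of attempts. The sketch below shows a simplified, self-contained variant of that pattern; it uses `schedule` rather than the patch's `scheduleAtFixedRate`, and the interval, attempt limit, object name, and `tryOnce` parameter are illustrative only.

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

object BoundedRetrySketch {
  private val scheduler: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor()

  /** Try `tryOnce` now; on failure, retry every 20 seconds until `maxAttempts` is reached. */
  def retry(attempt: Int, maxAttempts: Int)(tryOnce: () => Boolean): Unit = {
    if (tryOnce()) {
      scheduler.shutdownNow()        // succeeded: stop scheduling further attempts
    } else if (attempt >= maxAttempts) {
      scheduler.shutdownNow()        // give up, much as markDead(...) does above
    } else {
      scheduler.schedule(new Runnable {
        override def run(): Unit = retry(attempt + 1, maxAttempts)(tryOnce)
      }, 20, TimeUnit.SECONDS)
    }
  }
}
```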
appId = appId_ registered = true - changeMaster(masterUrl) + master = Some(masterRef) listener.connected(appId) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not + // guaranteed), `ExecutorStateChanged` may be sent to a dead master. + sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -142,24 +184,32 @@ private[spark] class AppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + master = Some(masterRef) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) + masterRef.send(MasterChangeAcknowledged(appId)) + } - case DisassociatedEvent(_, address, _) if address == masterAddress => + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopAppClient => + markDead("Application has been stopped.") + sendToMaster(UnregisterApplication(appId)) + context.reply(true) + stop() + } + + override def onDisconnected(address: RpcAddress): Unit = { + if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = { + if (isPossibleMaster(address)) { logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - master ! UnregisterApplication(appId) - sender ! true - context.stop(self) + } } /** @@ -179,28 +229,31 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop(): Unit = { + if (registrationRetryTimer != null) { + registrationRetryTimer.cancel(true) + } + registrationRetryThread.shutdownNow() + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() } } def start() { // Just launch an actor; it will call back into the listener. 
- actor = actorSystem.actorOf(Props(new ClientActor)) + endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) } def stop() { - if (actor != null) { + if (endpoint != null) { try { - val timeout = RpcUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + endpoint.askWithRetry[Boolean](StopAppClient) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - actor = null + endpoint = null } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 40835b9550586..1c79089303e3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -46,13 +47,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, - conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala index c5ac45c6730d3..a98b1fa8f83a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.client private[spark] object TestExecutor { def main(args: Array[String]) { + // scalastyle:off println println("Hello world!") + // scalastyle:on println while (true) { Thread.sleep(1000) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 298a8201960d1..5f5e0fe1c34d7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,6 +17,9 @@ package org.apache.spark.deploy.history +import java.util.zip.ZipOutputStream + +import org.apache.spark.SparkException import org.apache.spark.ui.SparkUI private[spark] case class ApplicationAttemptInfo( @@ -62,4 +65,12 @@ private[history] abstract class ApplicationHistoryProvider { */ def getConfig(): Map[String, String] = Map() + /** + * Writes out the event logs to the output stream provided. The logs will be compressed into a + * single zip file and written out. 
+ * @throws SparkException if the logs for the app id cannot be found. + */ + @throws(classOf[SparkException]) + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 45c2be34c8680..2cc465e55fceb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,16 +17,18 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, IOException, InputStream} +import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable +import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.AccessControlException -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ @@ -59,7 +61,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -80,12 +83,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - // Constants used to parse Spark 1.0.0 log directories. - private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. 
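The new `writeEventLogs` hook streams an application's event logs into a caller-supplied `ZipOutputStream`, one `ZipEntry` per log file, and throws `SparkException` when the app id is unknown (the `FsHistoryProvider` implementation follows below). A rough standalone sketch of the same zip-to-stream pattern over local files, using only `java.util.zip`; the file names here are made up for illustration:

```scala
import java.io.FileOutputStream
import java.nio.file.{Files, Paths}
import java.util.zip.{ZipEntry, ZipOutputStream}

object ZipLogsSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical log files to bundle; in the real provider these come from the Hadoop FileSystem.
    val logFiles = Seq("eventLog_app-0001", "eventLog_app-0001_2").map(Paths.get(_))
    val zipOut = new ZipOutputStream(new FileOutputStream("eventLogs.zip"))
    try {
      logFiles.filter(Files.exists(_)).foreach { path =>
        zipOut.putNextEntry(new ZipEntry(path.getFileName.toString)) // entry name = file name
        Files.copy(path, zipOut)                                     // stream the file's bytes into the zip
        zipOut.closeEntry()
      }
    } finally {
      zipOut.close() // the provider likewise closes the stream in a finally block
    }
  }
}
```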
@@ -143,7 +140,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { try { applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).map { attempt => + appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => val replayBus = new ReplayListenerBus() val ui = { val conf = this.conf.clone() @@ -152,20 +149,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() replayBus.addListener(appListener) val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - - ui.setAppName(s"${appInfo.name} ($appId)") - - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui + appInfo.map { info => + ui.setAppName(s"${info.name} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, + appListener.viewAcls.getOrElse("")) + ui + } } } } catch { @@ -219,6 +216,58 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + + /** + * This method compresses the files passed in, and writes the compressed data out into the + * [[OutputStream]] passed in. Each file is written as a new [[ZipEntry]] with its name being + * the name of the file being compressed. + */ + def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = { + val fs = FileSystem.get(hadoopConf) + val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer + try { + outputStream.putNextEntry(new ZipEntry(entryName)) + ByteStreams.copy(inputStream, outputStream) + outputStream.closeEntry() + } finally { + inputStream.close() + } + } + + applications.get(appId) match { + case Some(appInfo) => + try { + // If no attempt is specified, or there is no attemptId for attempts, return all attempts + appInfo.attempts.filter { attempt => + attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get + }.foreach { attempt => + val logPath = new Path(logDir, attempt.logPath) + // If this is a legacy directory, then add the directory to the zipStream and add + // each file to that directory. 
+ if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { + val files = fs.listStatus(logPath) + zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) + zipStream.closeEntry() + files.foreach { file => + val path = file.getPath + zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) + } + } else { + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) + } + } + } finally { + zipStream.close() + } + case None => throw new SparkException(s"Logs for $appId not found.") + } + } + + /** * Replay the log files in the list and merge the list of old applications with new ones */ @@ -227,8 +276,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = logs.flatMap { fileStatus => try { val res = replay(fileStatus, bus) - logInfo(s"Application log ${res.logPath} loaded successfully.") - Some(res) + res match { + case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") + case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + } + res } catch { case e: Exception => logError( @@ -374,9 +427,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replays the events in the specified log file and returns information about the associated - * application. + * application. Return `None` if the application ID cannot be located. */ - private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = { + private def replay( + eventLog: FileStatus, + bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") val logInput = @@ -390,16 +445,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted) - new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. Some old versions of Spark generate logs without an app ID, so let + // logs generated by those versions go through. + if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + Some(new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(eventLog).get, + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted)) + } else { + None + } } finally { logInput.close() } @@ -474,10 +537,34 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Returns whether the version of Spark that generated logs records app IDs. App IDs were added + * in Spark 1.1. 
+ */ + private def sparkVersionHasAppId(entry: FileStatus): Boolean = { + if (isLegacyLogDirectory(entry)) { + fs.listStatus(entry.getPath()) + .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } + .map { status => + val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) + version != "1.0" && version != "1.1" + } + .getOrElse(true) + } else { + true + } + } + } -private object FsHistoryProvider { +private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + // Constants used to parse Spark 1.0.0 log directories. + val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 517cbe5176241..a076a9c3f984d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException +import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import com.google.common.cache._ @@ -25,7 +26,8 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.status.api.v1.{ApplicationInfo, ApplicationsListResource, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, + UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.{SignalLogger, Utils} @@ -125,7 +127,7 @@ class HistoryServer( def initialize() { attachPage(new HistoryPage(this)) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) @@ -172,6 +174,13 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + provider.writeEventLogs(appId, attemptId, zipStream) + } + /** * Returns the provider configuration to show in the listing page. 
* @@ -219,7 +228,7 @@ object HistoryServer extends Logging { val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) - val provider = Class.forName(providerName) + val provider = Utils.classForName(providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index a2a97a7877ce7..18265df9faa2c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null @@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin Utils.loadDefaultSparkProperties(conf, propertiesFile) private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( """ |Usage: HistoryServer [options] @@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin | spark.history.fs.updateInterval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) + // scalastyle:on println System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 1620e95bea218..aa54ed9360f36 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,9 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +32,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fccceb3ea528b..245b047e7dfbd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,20 +21,18 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path +import 
org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -47,23 +45,27 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[master] class Master( - host: String, - port: Int, + override val rpcEnv: RpcEnv, + address: RpcAddress, webUiPort: Int, val securityMgr: SecurityManager, val conf: SparkConf) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends ThreadSafeRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + + // TODO Remove it once we don't use akka.serialization.Serialization + private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) @@ -75,10 +77,10 @@ private[master] class Master( val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] - private val addressToWorker = new HashMap[Address, WorkerInfo] + private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] - private val actorToApp = new HashMap[ActorRef, ApplicationInfo] - private val addressToApp = new HashMap[Address, ApplicationInfo] + private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 private val appIdToUI = new HashMap[String, SparkUI] @@ -89,21 +91,22 @@ private[master] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(address.host, "Expected hostname") private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) private val masterSource = new MasterSource(this) - private val webUi = new MasterWebUI(this, webUiPort) + // After onStart, webUi will be set + private var webUi: MasterWebUI = null private val masterPublicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host + if (envVar != null) envVar else address.host } - private val masterUrl = "spark://" + host + ":" + port + private val masterUrl = 
address.toSparkURL private var masterWebUiUrl: String = _ private var state = RecoveryState.STANDBY @@ -112,7 +115,9 @@ private[master] class Master( private var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ + + private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -130,20 +135,23 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) + Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) } else { None } private val restServerBoundPort = restServer.map(_.start()) - override def preStart() { + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new MasterWebUI(this, webUiPort) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CheckForWorkerTimeOut) + } + }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -157,16 +165,16 @@ private[master] class Master( case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => - val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) + val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(context.system)) + .newInstance(conf, SerializationExtension(actorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -176,18 +184,17 @@ private[master] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! 
- logError("Master actor restarted due to exception", reason) - } - - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) } + if (checkForWorkerTimeOutTask != null) { + checkForWorkerTimeOutTask.cancel(true) + } + forwardMessageThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -197,14 +204,14 @@ private[master] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -215,8 +222,11 @@ private[master] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = forwardMessageThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CompleteRecovery) + } + }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } } @@ -227,111 +237,42 @@ private[master] class Master( System.exit(0) } - case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) => - { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - sender, workerUiPort, publicAddress) + workerRef, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) - } - } - } - - case RequestSubmitDriver(description) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only accept driver submissions in ALIVE state." - sender ! 
SubmitDriverResponse(false, None, msg) - } else { - logInfo("Driver submitted " + description.command.mainClass) - val driver = createDriver(description) - persistenceEngine.addDriver(driver) - waitingDrivers += driver - drivers.add(driver) - schedule() - - // TODO: It might be good to instead have the submission client poll the master to determine - // the current status of the driver. For now it's simply "fire and forget". - - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") - } - } - - case RequestKillDriver(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - s"Can only kill drivers in ALIVE state." - sender ! KillDriverResponse(driverId, success = false, msg) - } else { - logInfo("Asked to kill driver " + driverId) - val driver = drivers.find(_.id == driverId) - driver match { - case Some(d) => - if (waitingDrivers.contains(d)) { - waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) - } else { - // We just notify the worker to kill the driver here. The final bookkeeping occurs - // on the return path when the worker submits a state change back to the master - // to notify it that the driver was successfully killed. - d.worker.foreach { w => - w.actor ! KillDriver(driverId) - } - } - // TODO: It would be nice for this to be a synchronous response - val msg = s"Kill request for $driverId submitted" - logInfo(msg) - sender ! KillDriverResponse(driverId, success = true, msg) - case None => - val msg = s"Driver $driverId has already finished or does not exist" - logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) - } - } - } - - case RequestDriverStatus(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only request driver status in ALIVE state." - sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) - } else { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } - case RegisterApplication(description) => { + case RegisterApplication(description, driver) => { + // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response } else { logInfo("Registering app " + description.name) - val app = createApplication(description, sender) + val app = createApplication(description, driver) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + driver.send(RegisteredApplication(app.id, self)) schedule() } } @@ -343,7 +284,7 @@ private[master] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! 
ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -384,7 +325,7 @@ private[master] class Master( } } - case Heartbeat(workerId) => { + case Heartbeat(workerId, worker) => { idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -392,7 +333,7 @@ private[master] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + worker.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -444,30 +385,103 @@ private[master] class Master( logInfo(s"Received unregister request from application $applicationId") idToApp.get(applicationId).foreach(finishApplication) - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() } + } - case RequestMasterState => { - sender ! MasterStateResponse( - host, port, restServerBoundPort, - workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." + context.reply(SubmitDriverResponse(self, false, None, msg)) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + context.reply(SubmitDriverResponse(self, true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) + } } - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. 
+ d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + context.reply(KillDriverResponse(self, driverId, success = true, msg)) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } + } + } + + case RequestDriverStatus(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + context.reply( + DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + context.reply(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) + case None => + context.reply(DriverStatusResponse(found = false, None, None, None, None)) + } + } + } + + case RequestMasterState => { + context.reply(MasterStateResponse( + address.host, address.port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } case BoundPortsRequest => { - sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) + context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) } } + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + private def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 @@ -479,7 +493,7 @@ private[master] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -496,7 +510,7 @@ private[master] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.endpoint.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -505,10 +519,8 @@ private[master] class Master( private def completeRecovery() { // Ensure "only-once" recovery semantics using a short synchronization period. - synchronized { - if (state != RecoveryState.RECOVERING) { return } - state = RecoveryState.COMPLETING_RECOVERY - } + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) @@ -623,10 +635,10 @@ private[master] class Master( private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! 
LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.endpoint.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { @@ -638,7 +650,7 @@ private[master] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -661,11 +673,11 @@ private[master] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.endpoint.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -687,14 +699,15 @@ private[master] class Master( schedule() } - private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): + ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -703,7 +716,7 @@ private[master] class Master( applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(app.driver) = app + endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app } @@ -717,8 +730,8 @@ private[master] class Master( logInfo("Removing app " + app.id) apps -= app idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address + endpointToApp -= app.driver + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -735,19 +748,19 @@ private[master] class Master( for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) + exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id)) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! 
ApplicationFinished(app.id) + w.endpoint.send(ApplicationFinished(app.id)) } } } @@ -768,7 +781,7 @@ private[master] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, None, app.desc.eventLogCodec) + eventLogDir, app.id, app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) @@ -832,14 +845,14 @@ private[master] class Master( private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray for (worker <- toRemove) { if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) + worker.id, WORKER_TIMEOUT_MS / 1000)) removeWorker(worker) } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } @@ -862,7 +875,7 @@ private[master] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.endpoint.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -891,57 +904,33 @@ private[master] class Master( } private[deploy] object Master extends Logging { - val systemName = "sparkMaster" - private val actorName = "Master" + val SYSTEM_NAME = "sparkMaster" + val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() - } - - /** - * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaUrl(sparkUrl: String, protocol: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - AkkaUtils.address(protocol, systemName, host, port, actorName) - } - - /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. 
- * - * @throws SparkException if the url is invalid - */ - def toAkkaAddress(sparkUrl: String, protocol: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address(protocol, systemName, host, port) + val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** - * Start the Master and return a four tuple of: - * (1) The Master actor system - * (2) The bound port - * (3) The web UI bound port - * (4) The REST server bound port, if any + * Start the Master and return a three tuple of: + * (1) The Master RpcEnv + * (2) The web UI bound port + * (3) The REST server bound port, if any */ - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { + conf: SparkConf): (RpcEnv, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf( - Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = RpcUtils.askTimeout(conf) - val portsRequest = actor.ask(BoundPortsRequest)(timeout) - val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] - (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) + val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) + val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, + new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) + val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 435b9b12f83b8..44cefbc77f08e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
*/ private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Master [options]\n" + "\n" + @@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8080)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 15c6296888f70..68c937188b333 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,7 +28,7 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to Master + // Master to itself case object CheckForWorkerTimeOut diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 9c3f79f1244b7..66a9ff38678c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -30,6 +30,11 @@ private[spark] class MasterSource(val master: Master) extends Source { override def getValue: Int = master.workers.size }) + // Gauge for alive worker numbers in cluster + metricRegistry.register(MetricRegistry.name("aliveWorkers"), new Gauge[Int]{ + override def getValue: Int = master.workers.filter(_.state == WorkerState.ALIVE).size + }) + // Gauge for application numbers in cluster metricRegistry.register(MetricRegistry.name("apps"), new Gauge[Int] { override def getValue: Int = master.apps.size diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 9b3d48c6edc84..f751966605206 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val endpoint: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { @@ -107,4 +105,6 @@ private[spark] class WorkerInfo( def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 52758d6a7c4be..6fdff86f66e01 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,10 +17,7 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import 
org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} import org.apache.spark.deploy.SparkCuratorUtil diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 80db6d474b5c1..328d95a7a0c68 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { - + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 06e265f99e231..e28e7e379ac91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask - import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorDesc @@ -32,14 +29,12 @@ import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? 
RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 756927682cd24..c3e20ebf8d6eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - Await.result(stateFuture, timeout) + master.askWithRetry[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { @@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } def handleDriverKillRequest(request: HttpServletRequest): Unit = { - handleKillRequest(request, id => { master ! RequestKillDriver(id) }) + handleKillRequest(request, id => { + master.ask[KillDriverResponse](RequestKillDriver(id)) + }) } private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { @@ -75,6 +73,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) + val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", @@ -108,12 +107,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { }.getOrElse { Seq.empty } } -
              <li><strong>Workers:</strong> {state.workers.size}</li>
-              <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
-                {state.workers.map(_.coresUsed).sum} Used</li>
-              <li><strong>Memory:</strong>
-                {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
-                {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
+              <li><strong>Alive Workers:</strong> {aliveWorkers.size}</li>
+              <li><strong>Cores in use:</strong> {aliveWorkers.map(_.cores).sum} Total,
+                {aliveWorkers.map(_.coresUsed).sum} Used</li>
+              <li><strong>Memory in use:</strong>
+                {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total,
+                {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used</li>
               <li><strong>Applications:</strong>
                 {state.activeApps.size} Running,
                 {state.completedApps.size} Completed</li>
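`ApplicationPage` and `MasterPage` above now call `master.askWithRetry[MasterStateResponse](RequestMasterState)` in place of the Akka `ask` + `Await.result` pair. A self-contained sketch of the ask-with-retry idea as a plain-`Future` helper; this is my own simplified approximation, not Spark's `RpcEndpointRef` implementation, and the retry count and timeout defaults are illustrative:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.control.NonFatal

object AskWithRetrySketch {
  // Roughly the contract askWithRetry offers callers such as MasterPage: block for a reply,
  // retrying a bounded number of times before giving up.
  def askWithRetry[T](ask: () => Future[T],
                      retries: Int = 3,
                      timeout: FiniteDuration = 10.seconds): T = {
    var attempt = 0
    var lastError: Throwable = null
    while (attempt < retries) {
      attempt += 1
      try {
        return Await.result(ask(), timeout)
      } catch {
        case NonFatal(e) => lastError = e // remember the failure and try again
      }
    }
    throw new RuntimeException(s"Ask failed after $retries attempts", lastError)
  }

  def main(args: Array[String]): Unit = {
    // Stand-in for master.askWithRetry[MasterStateResponse](RequestMasterState).
    val state = askWithRetry(() => Future { "MasterStateResponse(ALIVE)" })
    println(state)
  }
}
```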
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
index be8560d10fc62..e8ef60bd5428a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
@@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
       retryHeaders, retryRow, Iterable.apply(driverState.description.retryState))
     val content =

    Driver state information for driver id {driverId}
    - Back to Drivers + Back to Drivers
    Driver state: {driverState.state}
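The `RestSubmissionClient` hunk below stops treating every master URL as a comma-separated standalone list: only `spark://` URLs go through `Utils.parseStandaloneMasterUrls`, while `mesos://` URLs are kept as a single entry. A minimal, self-contained sketch of that dispatch follows; the local `parseStandaloneMasterUrls` re-implementation is an assumption for illustration, not Spark's own helper.

```scala
object MasterUrlParsing {
  // Hypothetical stand-in for Utils.parseStandaloneMasterUrls: split a
  // spark://host1:port1,host2:port2 URL into one URL per standalone master.
  private def parseStandaloneMasterUrls(master: String): Array[String] =
    master.stripPrefix("spark://").split(",").map("spark://" + _)

  /** Only standalone (spark://) URLs may name several masters; anything else is used as-is. */
  def resolveMasters(master: String): Array[String] =
    if (master.startsWith("spark://")) parseStandaloneMasterUrls(master)
    else Array(master)

  def main(args: Array[String]): Unit = {
    println(resolveMasters("spark://m1:7077,m2:7077").mkString(", "))
    println(resolveMasters("mesos://zk://zk1:2181/mesos").mkString(", "))
  }
}
```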
    diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 6078f50518ba4..1fe956320a1b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -57,7 +57,11 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private val supportedMasterPrefixes = Seq("spark://", "mesos://") - private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + private val masters: Array[String] = if (master.startsWith("spark://")) { + Utils.parseStandaloneMasterUrls(master) + } else { + Array(master) + } // Set of masters that lost contact with us, used to keep track of // whether there are masters still alive for us to communicate with diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 502b9bb701ccf..d5b9bcab1423f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import akka.actor.ActorRef import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** @@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to * @param masterConf the conf used by the Master - * @param masterActor reference to the Master actor to which requests can be sent + * @param masterEndpoint reference to the Master endpoint to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { protected override val submitRequestServlet = - new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = - new StandaloneKillRequestServlet(masterActor, masterConf) + new StandaloneKillRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = - new StandaloneStatusRequestServlet(masterActor, masterConf) + new StandaloneStatusRequestServlet(masterEndpoint, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. 
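In the rest-server hunks here, the explicit `AkkaUtils.askWithReply` calls with a hand-built `askTimeout` are replaced by a single typed `masterEndpoint.askWithRetry[Response](message)` on the new `RpcEndpointRef`. The sketch below only illustrates that ask-with-retry shape using plain Scala futures; the method name, retry count, and timeout here are assumptions, not the `RpcEndpointRef` API.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.control.NonFatal

object AskWithRetrySketch {
  // Block on an asynchronous ask, retrying a fixed number of times before giving up.
  def askWithRetry[T](retries: Int, timeout: FiniteDuration)(ask: () => Future[T]): T = {
    var lastError: Throwable = null
    for (_ <- 1 to retries) {
      try {
        return Await.result(ask(), timeout)
      } catch {
        case NonFatal(e) => lastError = e
      }
    }
    throw new RuntimeException(s"ask failed after $retries attempts", lastError)
  }

  def main(args: Array[String]): Unit = {
    import scala.concurrent.ExecutionContext.Implicits.global
    var attempts = 0
    val answer = askWithRetry(retries = 3, timeout = 1.second) { () =>
      Future { attempts += 1; if (attempts < 2) sys.error("flaky endpoint") else "ok" }
    }
    println(s"answer=$answer after $attempts attempts")
  }
}
```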
*/ -private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( - DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion k.message = response.message @@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( - DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion @@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ private[rest] class StandaloneSubmitRequestServlet( - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String, conf: SparkConf) extends SubmitRequestServlet { @@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion submitResponse.message = response.message diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index e6615a3174ce1..ef5a7e35ad562 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage { */ def fromJson(json: String): SubmitRestProtocolMessage = { val className = parseAction(json) - val clazz = Class.forName(packagePrefix + "." + className) + val clazz = Utils.classForName(packagePrefix + "." 
+ className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 8198296eeb341..868cc35d06ef3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -59,7 +59,7 @@ private[mesos] class MesosSubmitRequestServlet( extends SubmitRequestServlet { private val DEFAULT_SUPERVISE = false - private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // mb private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 0a1d60f58bc58..45a3f43045437 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConversions._ import scala.collection.Map import org.apache.spark.Logging +import org.apache.spark.SecurityManager import org.apache.spark.deploy.Command import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils @@ -40,12 +41,14 @@ object CommandUtils extends Logging { */ def buildProcessBuilder( command: Command, + securityMgr: SecurityManager, memory: Int, sparkHome: String, substituteArguments: String => String, classPaths: Seq[String] = Seq[String](), env: Map[String, String] = sys.env): ProcessBuilder = { - val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env) + val localCommand = buildLocalCommand( + command, securityMgr, substituteArguments, classPaths, env) val commandSeq = buildCommandSeq(localCommand, memory, sparkHome) val builder = new ProcessBuilder(commandSeq: _*) val environment = builder.environment() @@ -69,6 +72,7 @@ object CommandUtils extends Logging { */ private def buildLocalCommand( command: Command, + securityMgr: SecurityManager, substituteArguments: String => String, classPath: Seq[String] = Seq[String](), env: Map[String, String]): Command = { @@ -76,20 +80,26 @@ object CommandUtils extends Logging { val libraryPathEntries = command.libraryPathEntries val cmdLibraryPath = command.environment.get(libraryPathName) - val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { + var newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator))) } else { command.environment } + // set auth secret to env variable if needed + if (securityMgr.isAuthenticationEnabled) { + newEnvironment += (SecurityManager.ENV_AUTH_SECRET -> securityMgr.getSecretKey) + } + Command( command.mainClass, command.arguments.map(substituteArguments), newEnvironment, command.classPathEntries ++ classPath, Seq[String](), // library path already captured in environment variable - command.javaOpts) + // filter out auth secret from java options + command.javaOpts.filterNot(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) } /** Spawn a thread that will redirect a given stream to a file */ diff --git 
a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index ef7a703bffe67..ec51c3d935d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -21,7 +21,6 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{Utils, Clock, SystemClock} /** @@ -43,7 +43,7 @@ private[deploy] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String, val securityManager: SecurityManager) extends Logging { @@ -85,8 +85,8 @@ private[deploy] class DriverRunner( } // TODO: If we add ability to submit multiple jars they should also be added here - val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, - sparkHome.getAbsolutePath, substituteVariables) + val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager, + driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables) launchDriver(builder, driverDir, driverDesc.supervise) } catch { @@ -107,7 +107,7 @@ private[deploy] class DriverRunner( finalState = Some(state) - worker ! DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index d1a12b01e78f7..6799f78ec0c19 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -53,14 +53,16 @@ object DriverWrapper { Thread.currentThread.setContextClassLoader(loader) // Delegate to supplied main class - val clazz = Class.forName(mainClass, true, loader) + val clazz = Utils.classForName(mainClass) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) rpcEnv.shutdown() case _ => + // scalastyle:off println System.err.println("Usage: DriverWrapper [options]") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 7aa85b732fc87..29a5042285578 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -21,11 +21,11 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import 
org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.util.Utils @@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val webUiPort: Int, @@ -91,7 +91,7 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -125,8 +125,8 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory, - sparkHome.getAbsolutePath, substituteVariables) + val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) @@ -159,7 +159,7 @@ private[deploy] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 8f3cc54051048..82e9578bbcba5 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,15 +21,14 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.ExecutionContext import scala.util.Random - -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} @@ -38,32 +37,39 @@ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} -/** - * @param masterAkkaUrls Each url should be a valid akka url. 
- */ private[worker] class Worker( - host: String, - port: Int, + override val rpcEnv: RpcEnv, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], - actorSystemName: String, - actorName: String, + masterRpcAddresses: Array[RpcAddress], + systemName: String, + endpointName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends ThreadSafeRpcEndpoint with Logging { + + private val host = rpcEnv.address.host + private val port = rpcEnv.address.port Utils.checkHost(host, "Expected hostname") assert (port > 0) + // A scheduled executor used to send messages at the specified time. + private val forwordMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + + // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` + // methods. + private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) + // For worker and executor IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -79,32 +85,26 @@ private[worker] class Worker( val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * - REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 - * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)) + private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)) private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") - private var master: ActorSelection = null - private var masterAddress: Address = null + private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val akkaUrl = AkkaUtils.address( - AkkaUtils.protocol(context.system), - actorSystemName, - host, - port, - actorName) - @volatile private var registered = false - @volatile private var connected = false + private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private var registered = false + private var connected = false private val workerId = generateWorkerId() private val sparkHome = if (testing) { @@ -136,7 +136,18 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - private var registrationRetryTimer: Option[Cancellable] = None + private var registerMasterFutures: Array[JFuture[_]] = 
null + private var registrationRetryTimer: Option[JScheduledFuture[_]] = None + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) var coresUsed = 0 var memoryUsed = 0 @@ -162,14 +173,13 @@ private[worker] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -181,24 +191,32 @@ private[worker] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(url: String, uiUrl: String) { + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. - activeMasterUrl = url + activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) + master = Some(masterRef) connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } - private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + masterRpcAddresses.map { masterAddress => + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + }) } } @@ -211,8 +229,7 @@ private[worker] class Worker( Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") /** @@ -235,21 +252,48 @@ private[worker] class Worker( * still not safe if the old master recovers within this interval, but this is a much * less likely scenario. */ - if (master != null) { - master ! 
RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) - } else { - // We are retrying the initial registration - tryRegisterAllMasters() + master match { + case Some(masterRef) => + // registered == false && master != None means we lost the connection to master, so + // masterRef cannot be used and we need to recreate it again. Note: we must not set + // master to None due to the above comments. + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + val masterAddress = masterRef.address + registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + })) + case None => + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + // We are retrying the initial registration + registerMasterFutures = tryRegisterAllMasters() } // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = Some( + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) } } else { logError("All masters are unresponsive! Giving up.") @@ -258,41 +302,67 @@ private[worker] class Worker( } } + /** + * Cancel last registeration retry, or do nothing if no retry + */ + private def cancelLastRegistrationRetry(): Unit = { + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures = null + } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = None + } + private def registerWithMaster() { - // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // onDisconnected may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. 
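The new `registerMasterThreadPool` above is sized to `masterRpcAddresses.size` and backed by a `SynchronousQueue`, so each registration attempt gets its own thread and a blocking connect to one master cannot delay the others. A minimal sketch of the same construction with plain `java.util.concurrent` (placeholder addresses, a println in place of the real RegisterWorker send, and the default thread factory rather than Spark's named daemon factory):

```scala
import java.util.concurrent.{SynchronousQueue, ThreadPoolExecutor, TimeUnit}

object RegisterAllMastersSketch {
  def main(args: Array[String]): Unit = {
    val masterAddresses = Seq("master-1:7077", "master-2:7077") // placeholder addresses

    // Zero core threads, at most one thread per master, idle threads die after 60s.
    // A SynchronousQueue hands each submitted task straight to a fresh thread.
    val pool = new ThreadPoolExecutor(
      0, masterAddresses.size, 60L, TimeUnit.SECONDS,
      new SynchronousQueue[Runnable]())

    val futures = masterAddresses.map { address =>
      pool.submit(new Runnable {
        override def run(): Unit =
          // placeholder for the blocking "look up master endpoint and send RegisterWorker" step
          println(s"registering with $address on ${Thread.currentThread().getName}")
      })
    }

    futures.foreach(_.get()) // wait only so the demo prints before the JVM exits
    pool.shutdown()
  }
}
```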
registrationRetryTimer match { case None => registered = false - tryRegisterAllMasters() + registerMasterFutures = tryRegisterAllMasters() connectionAttemptCount = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + " attempt scheduled already.") } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterUrl, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterUrl) + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) registered = true - changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + changeMaster(masterRef, masterWebUiUrl) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(SendHeartbeat) + } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! Heartbeat(workerId) } + if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Copy ids so that it can be used in the cleanup thread. 
+ val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { val appDirs = workDir.listFiles() if (appDirs == null) { @@ -302,30 +372,27 @@ private[worker] class Worker( // the directory is used by an application - check that the application is not running // when cleaning up val appIdFromDir = dir.getName - val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + val isAppStillRunning = appIds.contains(appIdFromDir) dir.isDirectory && !isAppStillRunning && - !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => logInfo(s"Removing directory: ${dir.getPath}") Utils.deleteRecursively(dir) } - } + }(cleanupThreadExecutor) - cleanupFuture onFailure { + cleanupFuture.onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) - } + }(cleanupThreadExecutor) - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl, masterWebUiUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + changeMaster(masterRef, masterWebUiUrl) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) - - case Heartbeat => - logInfo(s"Received heartbeat from driver ${sender.path}") + masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case RegisterWorkerFailed(message) => if (!registered) { @@ -372,14 +439,14 @@ private[worker] class Worker( publicAddress, sparkHome, executorDir, - akkaUrl, + workerUri, conf, appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -387,14 +454,14 @@ private[worker] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! 
ExecutorStateChanged(appId, execId, state, message, exitStatus) + case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => + sendToMaster(executorStateChanged) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -437,7 +504,7 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl, + workerUri, securityMgr) drivers(driverId) = driver driver.start() @@ -456,7 +523,7 @@ private[worker] class Worker( } } - case DriverStateChanged(driverId, state, exception) => { + case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { state match { case DriverState.ERROR => logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") @@ -469,23 +536,13 @@ private[worker] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - master ! DriverStateChanged(driverId, state, exception) + sendToMaster(driverStageChanged) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem coresUsed -= driver.driverDesc.cores } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, drivers.values.toList, - finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) - case ReregisterWithMaster => reregisterWithMaster() @@ -494,6 +551,21 @@ private[worker] class Worker( maybeCleanupApplication(id) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestWorkerState => + context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (master.exists(_.address == remoteAddress)) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false @@ -513,13 +585,29 @@ private[worker] class Worker( } } + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. 
+ */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => + logWarning( + s"Dropping $message because the connection to master has not yet been established") + } + } + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { + cleanupThreadExecutor.shutdownNow() metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + cancelLastRegistrationRetry() + forwordMessageScheduler.shutdownNow() + registerMasterThreadPool.shutdownNow() executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() @@ -533,12 +621,12 @@ private[deploy] object Worker extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, @@ -547,18 +635,17 @@ private[deploy] object Worker extends Logging { masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None, - conf: SparkConf = new SparkConf): (ActorSystem, Int) = { + conf: SparkConf = new SparkConf): RpcEnv = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses, + systemName, actorName, workDir, conf, securityMgr)) + rpcEnv } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 9678631da9f6f..5181142c5f80e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
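The `sendToMaster` helper above keeps the active master as an `Option[RpcEndpointRef]` and drops messages sent before any master has registered, instead of dereferencing a possibly stale reference. A self-contained toy version of that guard (the `Ref` trait stands in for `RpcEndpointRef`; names are illustrative):

```scala
object SendToMasterSketch {
  // Minimal stand-in for an endpoint reference: something messages can be sent to.
  trait Ref { def send(message: Any): Unit }

  final class Worker {
    @volatile private var master: Option[Ref] = None

    def registered(ref: Ref): Unit = { master = Some(ref) }

    /** Forward to the current master, or drop the message if none has registered yet. */
    def sendToMaster(message: Any): Unit = master match {
      case Some(ref) => ref.send(message)
      case None => println(s"Dropping $message because no master is registered yet")
    }
  }

  def main(args: Array[String]): Unit = {
    val worker = new Worker
    worker.sendToMaster("Heartbeat") // dropped: no master yet
    worker.registered(new Ref {
      override def send(message: Any): Unit = println(s"master got $message")
    })
    worker.sendToMaster("Heartbeat") // delivered
  }
}
```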
*/ def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Worker [options] \n" + "\n" + @@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8081)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } @@ -147,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val ibmVendor = System.getProperty("java.vendor").contains("IBM") var totalMb = 0 try { + // scalastyle:off classforname val bean = ManagementFactory.getOperatingSystemMXBean() if (ibmVendor) { val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") @@ -157,14 +160,17 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt } + // scalastyle:on classforname } catch { case e: Exception => { totalMb = 2*1024 + // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + // scalastyle:on println } } // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) + math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) } def checkWorkerMemory(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 83fb991891a41..fae5640b9a213 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.worker import org.apache.spark.Logging -import org.apache.spark.deploy.DeployMessages.SendHeartbeat import org.apache.spark.rpc._ /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 88170d4df3053..5a1d06eb87db9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.worker.ui +import java.io.File +import java.net.URI import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -29,6 +31,7 @@ import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir + private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 @@ -129,6 +132,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offsetOption: Option[Long], byteLength: Int ): (String, Long, Long, Long) = { + + if (!supportedLogTypes.contains(logType)) { + return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) + } + + // Verify that the normalized path of the log directory is in the working directory + val normalizedUri = new URI(logDirectory).normalize() + val normalizedLogDir = new File(normalizedUri.getPath) + if (!Utils.isInDirectory(workDir, normalizedLogDir)) { + return ("Error: invalid log directory " + 
logDirectory, 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") @@ -144,7 +159,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offset } } - val endIndex = math.min(startIndex + totalLength, totalLength) + val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") val logText = Utils.offsetBytes(files, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 9f9f27d71e1ae..fd905feb97e92 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - private val workerActor = parent.worker.self - private val timeout = parent.timeout + private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? 
RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index b3bb5f911dbd7..334a5b10142aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - private[ui] val timeout = RpcUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ed159dec4f998..fcd76ec52742a 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -33,7 +33,7 @@ import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -55,18 +55,22 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() override def onStart() { - import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) ref.ask[RegisteredExecutor.type]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) - } onComplete { + }(ThreadUtils.sameThread).onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } - case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) - } + case Failure(e) => { + logError(s"Cannot register with driver: $driverUrl", e) + System.exit(1) + } + }(ThreadUtils.sameThread) } def extractLogUrls: Map[String, String] = { @@ -231,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { argv = tail case Nil => case tail => + // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + // scalastyle:on println printUsageAndExit() } } @@ -245,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } private def printUsageAndExit() = { + // scalastyle:off println System.err.println( """ |"Usage: CoarseGrainedExecutorBackend [options] @@ -258,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { | --worker-url | --user-class-path |""".stripMargin) + // scalastyle:on println System.exit(1) } diff --git 
a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8f916e0502ecb..1a02051c87f19 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -356,7 +356,7 @@ private[spark] class Executor( logInfo("Using REPL class URI: " + classUri) try { val _userClassPathFirst: java.lang.Boolean = userClassPathFirst - val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) @@ -443,7 +443,7 @@ private[spark] class Executor( try { val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") + logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 06152f16ae618..e80feeeab4142 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,11 +17,15 @@ package org.apache.spark.executor +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -43,22 +47,22 @@ class TaskMetrics extends Serializable { private var _hostname: String = _ def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value - + /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value - - + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value - + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ @@ -94,8 +98,8 @@ class TaskMetrics extends Serializable { */ private var _diskBytesSpilled: Long = _ def diskBytesSpilled: Long = _diskBytesSpilled - def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value + private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read @@ -210,10 +214,26 @@ class TaskMetrics extends Serializable { private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + // Get the 
hostname from cached data, since hostname is the order of number of nodes in + // cluster, so using cached hostname will decrease the object number and alleviate the GC + // overhead. + _hostname = TaskMetrics.getCachedHostName(_hostname) + } } private[spark] object TaskMetrics { + private val hostNameCache = new ConcurrentHashMap[String, String]() + def empty: TaskMetrics = new TaskMetrics + + def getCachedHostName(host: String): String = { + val canonicalHost = hostNameCache.putIfAbsent(host, host) + if (canonicalHost != null) canonicalHost else host + } } /** @@ -261,7 +281,7 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. @@ -315,7 +335,7 @@ class ShuffleReadMetrics extends Serializable { def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - + /** * Number of local blocks fetched in this shuffle by this task */ @@ -333,7 +353,7 @@ class ShuffleReadMetrics extends Serializable { def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - + /** * Total number of remote bytes read from the shuffle by this task */ @@ -381,7 +401,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @@ -389,7 +409,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** * Total number of records written to the shuffle by this task */ diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index c219d21fbefa9..532850dd57716 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil /** @@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat { } private[spark] class FixedLengthBinaryInputFormat - extends FileInputFormat[LongWritable, BytesWritable] { + extends FileInputFormat[LongWritable, BytesWritable] + with Logging { private var recordLength = -1 @@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) } if (recordLength <= 0) { - println("record 
length is less than 0, file cannot be split") + logDebug("record length is less than 0, file cannot be split") false } else { true diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0756cdb2ed8e6..607d5a321efca 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,7 +17,7 @@ package org.apache.spark.io -import java.io.{InputStream, OutputStream} +import java.io.{IOException, InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} @@ -63,8 +63,7 @@ private[spark] object CompressionCodec { def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) val codec = try { - val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) - .getConstructor(classOf[SparkConf]) + val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { case e: ClassNotFoundException => None @@ -154,8 +153,53 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt - new SnappyOutputStream(s, blockSize) + new SnappyOutputStreamWrapper(new SnappyOutputStream(s, blockSize)) } override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } + +/** + * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version + * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. 
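Above, the `TaskMetrics` hunk interns host names through a `ConcurrentHashMap` during deserialization, so a cluster with N hosts keeps N `String` instances rather than one per task. A self-contained sketch of that `putIfAbsent` interning idiom (object and host names are illustrative):

```scala
import java.util.concurrent.ConcurrentHashMap

object HostNameCacheSketch {
  private val hostNameCache = new ConcurrentHashMap[String, String]()

  /** Return one canonical String instance per distinct host name. */
  def getCachedHostName(host: String): String = {
    val canonical = hostNameCache.putIfAbsent(host, host)
    if (canonical != null) canonical else host
  }

  def main(args: Array[String]): Unit = {
    val a = new String("worker-01") // two equal but distinct String instances
    val b = new String("worker-01")
    println(a eq b)                                       // false: separate objects
    println(getCachedHostName(a) eq getCachedHostName(b)) // true: both resolve to the cached one
  }
}
```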
+ */ +private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends OutputStream { + + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte]): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b, off, len) + } + + override def flush(): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.flush() + } + + override def close(): Unit = { + if (!closed) { + closed = true + os.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 818f7a4c8d422..87df42748be44 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.util.{Utils => SparkUtils} private[spark] trait SparkHadoopMapRedUtil { @@ -64,10 +65,10 @@ trait SparkHadoopMapRedUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + SparkUtils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + SparkUtils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index cfd20392d12f1..943ebcb7bd0a1 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} +import org.apache.spark.util.Utils private[spark] trait SparkHadoopMapReduceUtil { @@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil { isMap: Boolean, taskId: Int, attemptId: Int): TaskAttemptID = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") + val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap // (not available in YARN) @@ -57,10 +58,10 @@ trait SparkHadoopMapReduceUtil { } catch { case exc: NoSuchMethodException => { // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if(isMap) "MAP" else "REDUCE") + taskTypeClass, if (isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), @@ -71,10 +72,10 @@ 
trait SparkHadoopMapReduceUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + Utils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + Utils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 8edf493780687..d7495551ad233 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -23,10 +23,10 @@ import java.util.Properties import scala.collection.mutable import scala.util.matching.Regex -import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} -private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { +private[spark] class MetricsConfig(conf: SparkConf) extends Logging { private val DEFAULT_PREFIX = "*" private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r @@ -46,23 +46,14 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi // Add default properties in case there's no properties file setDefaultProperties(properties) - // If spark.metrics.conf is not set, try to get file in class path - val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse { - try { - Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)) - } catch { - case e: Exception => - logError("Error loading default configuration file", e) - None - } - } + loadPropertiesFromFile(conf.getOption("spark.metrics.conf")) - isOpt.foreach { is => - try { - properties.load(is) - } finally { - is.close() - } + // Also look for the properties in provided Spark configuration + val prefix = "spark.metrics.conf." + conf.getAll.foreach { + case (k, v) if k.startsWith(prefix) => + properties.setProperty(k.substring(prefix.length()), v) + case _ => } propertyCategories = subProperties(properties, INSTANCE_REGEX) @@ -97,5 +88,31 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) } } -} + /** + * Loads configuration from a config file. If no config file is provided, try to get file + * in class path. 
+ */ + private[this] def loadPropertiesFromFile(path: Option[String]): Unit = { + var is: InputStream = null + try { + is = path match { + case Some(f) => new FileInputStream(f) + case None => Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME) + } + + if (is != null) { + properties.load(is) + } + } catch { + case e: Exception => + val file = path.getOrElse(DEFAULT_METRICS_CONF_FILENAME) + logError(s"Error loading configuration file $file", e) + } finally { + if (is != null) { + is.close() + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 9150ad35712a1..67f64d5e278de 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,6 +20,8 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit +import org.apache.spark.util.Utils + import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} @@ -70,8 +72,7 @@ private[spark] class MetricsSystem private ( securityMgr: SecurityManager) extends Logging { - private[this] val confFile = conf.get("spark.metrics.conf", null) - private[this] val metricsConfig = new MetricsConfig(Option(confFile)) + private[this] val metricsConfig = new MetricsConfig(conf) private val sinks = new mutable.ArrayBuffer[Sink] private val sources = new mutable.ArrayBuffer[Source] @@ -167,7 +168,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Class.forName(classPath).newInstance() + val source = Utils.classForName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) @@ -183,7 +184,7 @@ private[spark] class MetricsSystem private ( val classPath = kv._2.getProperty("class") if (null != classPath) { try { - val sink = Class.forName(classPath) + val sink = Utils.classForName(classPath) .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index e8b3074e8f1a6..11dfcfe2f04e1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -26,9 +26,9 @@ import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem private[spark] class Slf4jSink( - val property: Properties, + val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) + securityMgr: SecurityManager) extends Sink { val SLF4J_DEFAULT_PERIOD = 10 val SLF4J_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala index 90e3aa70b99ef..670e683663324 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala @@ -20,4 +20,4 @@ package org.apache.spark.metrics /** * Sinks used in Spark's metrics system. 
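// Illustrative sketch, not part of the patch: with MetricsConfig now taking the SparkConf,
// sinks can be configured through properties prefixed with "spark.metrics.conf." (the prefix
// is stripped before the usual instance.sink.* parsing), in addition to a metrics.properties
// file. The Slf4jSink settings below are an assumed example following that naming convention.
object MetricsConfSketch {
  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .set("spark.metrics.conf.*.sink.slf4j.class", "org.apache.spark.metrics.sink.Slf4jSink")
    .set("spark.metrics.conf.*.sink.slf4j.period", "10")
    .set("spark.metrics.conf.*.sink.slf4j.unit", "seconds")
}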
*/ -package object sink +package object sink diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index b573f1a8a5fcb..79cb0640c8672 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -57,16 +57,6 @@ private[nio] class BlockMessage() { } def set(buffer: ByteBuffer) { - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ typ = buffer.getInt() val idLength = buffer.getInt() val idBuilder = new StringBuilder(idLength) @@ -110,7 +100,7 @@ private[nio] class BlockMessage() { def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data - def getLevel: StorageLevel = level + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() @@ -138,24 +128,12 @@ private[nio] class BlockMessage() { buffers += data } - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 1ba25aa74aa02..f1c9ea8b64ca3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -43,16 +43,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) val newBlockMessages = new ArrayBuffer[BlockMessage]() val buffer = bufferMessage.buffers(0) buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ while (buffer.remaining() > 0) { val size = buffer.getInt() logDebug("Creating block message of size " + size + " bytes") @@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } } -private[nio] object BlockMessageArray { +private[nio] object BlockMessageArray extends Logging { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() @@ -114,8 +92,8 @@ private[nio] object BlockMessageArray { val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear + val buffer = ByteBuffer.allocate(100) + buffer.clear() BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, StorageLevel.MEMORY_ONLY_SER)) } else { @@ -123,10 +101,10 @@ private[nio] object BlockMessageArray { } } val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") + logDebug("Block message array created") val bufferMessage = 
blockMessageArray.toBufferMessage - println("Converted to buffer message") + logDebug("Converted to buffer message") val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) @@ -138,10 +116,11 @@ private[nio] object BlockMessageArray { }) newBuffer.flip val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) + logDebug("Copied to new buffer message, size = " + newBufferMessage.size) val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") + logDebug("Converted back to block message array") + // scalastyle:off println newBlockMessageArray.foreach(blockMessage => { blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { @@ -154,6 +133,7 @@ private[nio] object BlockMessageArray { } } }) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 6b898bd4bfc1b..1499da07bb83b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -326,15 +326,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // MUST be called within the selector loop def connect() { - try{ + try { channel.register(selector, SelectionKey.OP_CONNECT) channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { - case e: Exception => { + case e: Exception => logError("Error connecting to " + address, e) callOnExceptionCallbacks(e) - } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 497871ed6d5e5..9143918790381 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -635,12 +635,11 @@ private[nio] class ConnectionManager( val message = securityMsgResp.toBufferMessage if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => { + } catch { + case e: Exception => logError("Error handling sasl client authentication", e) waitingConn.close() throw new IOException("Error evaluating sasl response: ", e) - } } } } @@ -1017,7 +1016,9 @@ private[spark] object ConnectionManager { val conf = new SparkConf val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // scalastyle:off println println("Received [" + msg + "] from [" + id + "]") + // scalastyle:on println None }) @@ -1034,6 +1035,7 @@ private[spark] object ConnectionManager { System.gc() } + // scalastyle:off println def testSequentialSending(manager: ConnectionManager) { println("--------------------------") println("Sequential Sending") @@ -1151,4 +1153,5 @@ private[spark] object ConnectionManager { println() } } + // scalastyle:on println } diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index 747a2088a7258..232c552f9865d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ 
b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -75,7 +75,7 @@ private[nio] class SecurityMessage extends Logging { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - connectionId = idBuilder.toString() + connectionId = idBuilder.toString() val tokenLength = buffer.getInt() token = new Array[Byte](tokenLength) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 2ab41ba488ff6..8ae76c5f72f2e 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.4.0-SNAPSHOT" + val SPARK_VERSION = "1.5.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 3ef3cc219dec6..91b07ce3af1b6 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.collection.OpenHashMap * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] { + extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { var outputsMerged = 0 - var sums = new OpenHashMap[T,Long]() // Sum of counts for each key + var sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ec185340c3a2d..ca1eb1f4e4a9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,8 +19,10 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.util.ThreadUtils + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.ExecutionContext import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} @@ -66,6 +68,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val f = new ComplexFutureAction[Seq[T]] f.run { + // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which + // is a cached thread pool. 
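// Illustrative sketch, not part of the patch: takeAsync above now runs its blocking body on
// AsyncRDDActions.futureExecutionContext (a daemon cached thread pool defined later in this
// hunk) instead of scala.concurrent's global context. The same pattern, using only the
// standard library ("submitBlocking" and "BlockingPoolSketch" are hypothetical names):
object BlockingPoolSketch {
  import java.util.concurrent.Executors
  import scala.concurrent.{ExecutionContext, Future}

  // An unbounded cached pool keeps long-blocking futures from starving the small,
  // CPU-sized global ExecutionContext shared by other callers.
  private val blockingPool =
    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  def submitBlocking[T](body: => T): Future[T] = Future(body)(blockingPool)
}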
val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 @@ -81,9 +85,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi numPartsToTry = partsScanned * 4 } else { // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max(1, + numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } @@ -101,7 +105,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi partsScanned += numPartsToTry } results.toSeq - } + }(AsyncRDDActions.futureExecutionContext) f } @@ -123,3 +127,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi (index, data) => Unit, Unit) } } + +private object AsyncRDDActions { + val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("AsyncRDDActions-future", 128)) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 0d130dd4c7a60..e17bd47905d7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -21,15 +21,14 @@ import java.io.IOException import scala.reflect.ClassTag -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). @@ -38,9 +37,11 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) + private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + + override def getCheckpointFile: Option[String] = Some(checkpointPath) override def getPartitions: Array[Partition] = { val cpath = new Path(checkpointPath) @@ -49,7 +50,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.length + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! 
partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) @@ -60,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - override def getPreferredLocations(split: Partition): Seq[String] = { val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) @@ -75,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) CheckpointRDD.readFromFile(file, broadcastedConf, context) } - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } + // CheckpointRDD should not be checkpointed again + override def checkpoint(): Unit = { } + override def doCheckpoint(): Unit = { } } private[spark] object CheckpointRDD extends Logging { @@ -87,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T: ClassTag]( path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1 )(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get @@ -135,7 +133,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T]( path: Path, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], context: TaskContext ): Iterator[T] = { val env = SparkEnv.get @@ -164,7 +162,7 @@ private[spark] object CheckpointRDD extends Logging { val path = new Path(hdfsPath, "temp") val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf)) sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 0c1b02c07d09f..663eebb8e4191 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -310,11 +310,11 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: def throwBalls() { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { + for ((p, i) <- prev.partitions.zipWithIndex) { groupArr(i).arr += p } } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { + for (i <- 0 until maxPartitions) { val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2cefe63d44b20..f1c17369cb48c 100644 --- 
a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -100,7 +100,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp @DeveloperApi class HadoopRDD[K, V]( @transient sc: SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], @@ -121,8 +121,8 @@ class HadoopRDD[K, V]( minPartitions: Int) = { this( sc, - sc.broadcast(new SerializableWritable(conf)) - .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + sc.broadcast(new SerializableConfiguration(conf)) + .asInstanceOf[Broadcast[SerializableConfiguration]], None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, @@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging { private[spark] class SplitInfoReflections { val inputSplitWithLocationInfo = - Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") - val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit") val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") - val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo") val isInMemory = splitLocationInfo.getMethod("isInMemory") val getLocation = splitLocationInfo.getMethod("getLocation") } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2ab967f4bb313..f827270ee6a44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -74,7 +74,7 @@ class NewHadoopRDD[K, V]( with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) // private val serializableConf = new SerializableWritable(conf) private val jobTrackerId: String = { @@ -196,7 +196,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = 
hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a6d5d2c94e17f..91a6a2d039852 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -44,7 +44,7 @@ import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -296,6 +296,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope { + val cleanedF = self.sparkContext.clean(func) if (keyClass.isArray) { throw new SparkException("reduceByKeyLocally() does not support array keys") @@ -305,7 +306,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val map = new JHashMap[K, V] iter.foreach { pair => val old = map.get(pair._1) - map.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + map.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } Iterator(map) } : Iterator[JHashMap[K, V]] @@ -313,7 +314,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { m2.foreach { pair => val old = m1.get(pair._1) - m1.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] @@ -327,7 +328,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) reduceByKeyLocally(func) } - /** + /** * Count the number of elements for each key, collecting the results to a local Map. 
* * Note that this method should only be used if the resulting map is expected to be small, as @@ -466,7 +467,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 val bufs = combineByKey[CompactBuffer[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -1001,7 +1002,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1010,7 +1011,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, @@ -1026,7 +1027,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L Utils.tryWithSafeFinally { @@ -1064,7 +1065,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
val hadoopConf = conf - val wrappedConf = new SerializableWritable(hadoopConf) + val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 7598ff617b399..9e3880714a79f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -86,7 +86,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( } val location = if (locations.isEmpty) { None - } else { + } else { // Find the location that maximum number of parent partitions prefer Some(locations.groupBy(x => x).maxBy(_._2.length)._1) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index dc60d48927624..defdabf95ac4b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread("stderr reader for " + command) { override def run() { for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + // scalastyle:off println System.err.println(line) + // scalastyle:on println } } }.start() @@ -133,6 +135,7 @@ private[spark] class PipedRDD[T: ClassTag]( override def run() { val out = new PrintWriter(proc.getOutputStream) + // scalastyle:off println // input the pipe context firstly if (printPipeContext != null) { printPipeContext(out.println(_)) @@ -144,6 +147,7 @@ private[spark] class PipedRDD[T: ClassTag]( out.println(elem) } } + // scalastyle:on println out.close() } }.start() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 02a94baf372d9..9f7ebae3e9af3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag]( @transient private var partitions_ : Array[Partition] = null /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) /** * Get the list of dependencies of this RDD, taking into account whether the @@ -434,11 +434,11 @@ abstract class RDD[T: ClassTag]( * @return A random sub-sample of the RDD without replacement. */ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { - this.mapPartitionsWithIndex { case (index, partition) => + this.mapPartitionsWithIndex( { (index, partition) => val sampler = new BernoulliCellSampler[T](lb, ub) sampler.setSeed(seed + index) sampler.sample(partition) - } + }, preservesPartitioning = true) } /** @@ -454,7 +454,7 @@ abstract class RDD[T: ClassTag]( withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { - val numStDev = 10.0 + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") @@ -890,6 +890,10 @@ abstract class RDD[T: ClassTag]( * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. 
+ * + * Note: this results in multiple Spark jobs, and if the input RDD is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input RDD should be cached first. */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { @@ -1015,9 +1019,16 @@ abstract class RDD[T: ClassTag]( /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. */ def fold(zeroValue: T)(op: (T, T) => T): T = withScope { // Clone the zero value since we will also be serializing it as part of tasks @@ -1131,8 +1142,8 @@ abstract class RDD[T: ClassTag]( if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } - val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T,Long] = { (ctx, iter) => - val map = new OpenHashMap[T,Long] + val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T, Long] = { (ctx, iter) => + val map = new OpenHashMap[T, Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } @@ -1440,12 +1451,16 @@ abstract class RDD[T: ClassTag]( * executed on this RDD. It is strongly recommended that this RDD is persisted in * memory, otherwise saving it on a file will require recomputation. */ - def checkpoint() { + def checkpoint(): Unit = { if (context.checkpointDir.isEmpty) { throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() + // NOTE: we use a global lock here due to complexities downstream with ensuring + // children RDD partitions point to the correct parent partitions. In the future + // we should revisit this consideration. + RDDCheckpointData.synchronized { + checkpointData = Some(new RDDCheckpointData(this)) + } } } @@ -1486,7 +1501,7 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassTag] = { + protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } @@ -1524,7 +1539,7 @@ abstract class RDD[T: ClassTag]( * doCheckpoint() is called recursively on the parent RDDs. 
*/ private[spark] def doCheckpoint(): Unit = { - RDDOperationScope.withScope(sc, "checkpoint", false, true) { + RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) { if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { @@ -1578,15 +1593,15 @@ abstract class RDD[T: ClassTag]( case 0 => Seq.empty case 1 => val d = rdd.dependencies.head - debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]], true) case _ => val frontDeps = rdd.dependencies.take(len - 1) val frontDepStrings = frontDeps.flatMap( - d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]])) + d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]])) val lastDep = rdd.dependencies.last val lastDepStrings = - debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) (frontDepStrings ++ lastDepStrings) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 1722c27e55003..4f954363bed8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -22,15 +22,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.spark._ -import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} +import org.apache.spark.util.SerializableConfiguration /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + * [ Initialized --> checkpointing in progress --> checkpointed ]. */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value + val Initialized, CheckpointingInProgress, Checkpointed = Value } /** @@ -45,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) import CheckpointState._ // The checkpoint state of the associated RDD. - var cpState = Initialized + private var cpState = Initialized // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None + private var cpFile: Option[String] = None // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None + // This is defined if and only if `cpState` is `Checkpointed`. + private var cpRDD: Option[CheckpointRDD[T]] = None - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } + // TODO: are we sure we need to use a global lock in the following methods? // Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } + def isCheckpointed: Boolean = RDDCheckpointData.synchronized { + cpState == Checkpointed } // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } + def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized { + cpFile } - // Do the checkpointing of the RDD. 
Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return + /** + * Materialize this RDD and write its content to a reliable DFS. + * This is called immediately after the first action invoked on this RDD has completed. + */ + def doCheckpoint(): Unit = { + + // Guard against multiple threads checkpointing the same RDD by + // atomically flipping the state of this RDDCheckpointData RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { + if (cpState == Initialized) { cpState = CheckpointingInProgress } else { return @@ -86,18 +86,20 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) + throw new SparkException(s"Failed to create checkpoint path $path") } // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( - new SerializableWritable(rdd.context.hadoopConfiguration)) + new SerializableConfiguration(rdd.context.hadoopConfiguration)) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { rdd.context.cleaner.foreach { cleaner => cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) } } + + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( @@ -112,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed } - logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) - } - - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } + logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}") } - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } + def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { + cpRDD.get.partitions } - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } + def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { + cpRDD } } private[spark] object RDDCheckpointData { + + /** Return the path of the directory to which this RDD's checkpoint data is written. */ def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { - sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } + sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } } + /** Clean up the files associated with the checkpoint data for this RDD. 
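// Illustrative sketch, not part of the patch: the simplified state machine above
// (Initialized -> CheckpointingInProgress -> Checkpointed) is driven by the usual user-facing
// calls. The checkpoint directory and RDD below are assumed purely for illustration.
object CheckpointSketch {
  import org.apache.spark.{SparkConf, SparkContext}

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("checkpoint-sketch"))
    sc.setCheckpointDir("/tmp/checkpoints")  // must be set before calling checkpoint()
    val rdd = sc.parallelize(1 to 1000).map(_ * 2)
    rdd.cache()       // avoids recomputing the lineage when the checkpoint job runs (SPARK-8582 note above)
    rdd.checkpoint()  // only marks the RDD; nothing is written yet
    rdd.count()       // the first action triggers doCheckpoint() and materializes the data
    assert(rdd.isCheckpointed)
    sc.stop()
  }
}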
*/ def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { rddCheckpointDataPath(sc, rddId).foreach { path => val fs = path.getFileSystem(sc.hadoopConfiguration) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 93ec606f2de7d..44667281c1063 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * A general, named code block representing an operation that instantiates RDDs. @@ -43,9 +43,8 @@ import org.apache.spark.SparkContext @JsonPropertyOrder(Array("id", "name", "parent")) private[spark] class RDDOperationScope( val name: String, - val parent: Option[RDDOperationScope] = None) { - - val id: Int = RDDOperationScope.nextScopeId() + val parent: Option[RDDOperationScope] = None, + val id: String = RDDOperationScope.nextScopeId().toString) { def toJson: String = { RDDOperationScope.jsonMapper.writeValueAsString(this) @@ -75,7 +74,7 @@ private[spark] class RDDOperationScope( * A collection of utility methods to construct a hierarchical representation of RDD scopes. * An RDD scope tracks the series of operations that created a given RDD. */ -private[spark] object RDDOperationScope { +private[spark] object RDDOperationScope extends Logging { private val jsonMapper = new ObjectMapper().registerModule(DefaultScalaModule) private val scopeCounter = new AtomicInteger(0) @@ -88,15 +87,25 @@ private[spark] object RDDOperationScope { /** * Execute the given body such that all RDDs created in this body will have the same scope. - * The name of the scope will be the name of the method that immediately encloses this one. + * The name of the scope will be the first method name in the stack trace that is not the + * same as this method's. * * Note: Return statements are NOT allowed in body. 
*/ private[spark] def withScope[T]( sc: SparkContext, allowNesting: Boolean = false)(body: => T): T = { - val callerMethodName = Thread.currentThread.getStackTrace()(3).getMethodName - withScope[T](sc, callerMethodName, allowNesting)(body) + val ourMethodName = "withScope" + val callerMethodName = Thread.currentThread.getStackTrace() + .dropWhile(_.getMethodName != ourMethodName) + .find(_.getMethodName != ourMethodName) + .map(_.getMethodName) + .getOrElse { + // Log a warning just in case, but this should almost certainly never happen + logWarning("No valid method name for this RDD operation scope!") + "N/A" + } + withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body) } /** @@ -116,7 +125,7 @@ private[spark] object RDDOperationScope { sc: SparkContext, name: String, allowNesting: Boolean, - ignoreParent: Boolean = false)(body: => T): T = { + ignoreParent: Boolean)(body: => T): T = { // Save the old scope to restore it later val scopeKey = SparkContext.RDD_SCOPE_KEY val noOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 3dfcf67f0eb66..4b5f15dd06b85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -104,13 +104,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag if (!convertKey && !convertValue) { self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (x._1, anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), x._2)).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 633aeba3bbae6..f7cb1791d4ac6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -125,7 +125,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index a96b6c3d23454..81f40ad33aa5d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -123,7 +123,7 @@ private[spark] class ZippedPartitionsRDD3 } private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, 
B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( + [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 69181edb9ad44..6ae47894598be 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -17,8 +17,7 @@ package org.apache.spark.rpc -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.Future import scala.reflect.ClassTag import org.apache.spark.util.RpcUtils @@ -32,7 +31,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) private[this] val maxRetries = RpcUtils.numRetries(conf) private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) /** * return the address for the [[RpcEndpointRef]] @@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. */ - def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to @@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts var attempts = 0 var lastException: Exception = null @@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) attempts += 1 try { val future = ask[T](message, timeout) - val result = Await.result(future, timeout) + val result = timeout.awaitResult(future) if (result == null) { throw new SparkException("Actor returned null") } @@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) lastException = e logWarning(s"Error sending message [message = $message] in $attempts attempts", e) } - Thread.sleep(retryWaitMs) + + if (attempts < maxRetries) { + Thread.sleep(retryWaitMs) + } } throw new SparkException( s"Error sending message [message = $message]", lastException) } + } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 12b6b28d4d7ec..c9fcc7a36cc04 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -18,8 +18,10 @@ package org.apache.spark.rpc import java.net.URI +import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Future} +import scala.concurrent.{Awaitable, Await, Future} +import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.{SecurityManager, SparkConf} @@ -37,8 +39,7 @@ private[spark] object RpcEnv { val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") val rpcEnvName = conf.get("spark.rpc", "akka") val rpcEnvFactoryClassName = 
rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). - newInstance().asInstanceOf[RpcEnvFactory] + Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } def create( @@ -66,7 +67,7 @@ private[spark] object RpcEnv { */ private[spark] abstract class RpcEnv(conf: SparkConf) { - private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) + private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -94,7 +95,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. */ def setupEndpointRefByURI(uri: String): RpcEndpointRef = { - Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } /** @@ -158,6 +159,8 @@ private[spark] case class RpcAddress(host: String, port: Int) { val hostPort: String = host + ":" + port override val toString: String = hostPort + + def toSparkURL: String = "spark://" + hostPort } @@ -182,3 +185,107 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. 
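// Illustrative sketch, not part of the patch: how the RpcTimeout added above is meant to be
// used. "slowReply" is a hypothetical future that never completes; awaitResult rethrows the
// bare TimeoutException as an RpcTimeoutException naming the controlling configuration key.
object RpcTimeoutSketch {
  import scala.concurrent.{Future, Promise}
  import scala.concurrent.duration._

  val askTimeout = new RpcTimeout(2.seconds, "spark.rpc.askTimeout")
  val slowReply: Future[String] = Promise[String]().future  // never completed

  // Would throw an RpcTimeoutException whose message ends with
  // "This timeout is controlled by spark.rpc.askTimeout" after roughly two seconds:
  // askTimeout.awaitResult(slowReply)
}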
+ * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index ba0d468f111ef..f2d87f68341af 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future -import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -29,9 +28,11 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import com.google.common.util.concurrent.MoreExecutors + import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ -import 
org.apache.spark.util.{ActorLogReceive, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} /** * A RpcEnv implementation based on Akka. @@ -178,10 +179,10 @@ private[spark] class AkkaRpcEnv private[akka] ( }) } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - _sender ! AkkaFailure(e) - } else { + _sender ! AkkaFailure(e) + if (!needReply) { + // If the sender does not require a reply, it may not handle the exception. So we rethrow + // "e" to make sure it will be processed. throw e } } @@ -212,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] ( override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). + // this is just in case there is a timeout from creating the future in resolveOne, we want the + // exception to indicate the conf that determines the timeout + recover(defaultLookupTimeout.addMessageIfTimeout) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -293,9 +297,9 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { - import scala.concurrent.ExecutionContext.Implicits.global - actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { + // The function will run in the calling thread, so it should be short and never block. case msg @ AkkaMessage(message, reply) => if (reply) { logError(s"Receive $msg but the sender cannot reply") @@ -305,7 +309,8 @@ private[akka] class AkkaRpcEndpointRef( } case AkkaFailure(e) => Future.failed(e) - }.mapTo[T] + }(ThreadUtils.sameThread).mapTo[T]. 
+ recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } override def toString: String = s"${getClass.getSimpleName}($actorRef)" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5d812918a13d1..f3d87ee5c4fd1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ @@ -81,6 +82,8 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -137,6 +140,22 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + // Flag to control if reduce tasks are assigned preferred locations + private val shuffleLocalityEnabled = + sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + // Number of map, reduce tasks above which we do not assign preferred locations + // based on map output sizes. We limit the size of jobs for which assign preferred locations + // as computing the top locations by size becomes expensive. + private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. + // Making this larger will focus on fewer locations where most data can be read locally, but + // may lead to more delay in scheduling if those locations are busy. + private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 + // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -170,7 +189,7 @@ class DAGScheduler( blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( - BlockManagerHeartbeat(blockManagerId), 600 seconds) + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } // Called by TaskScheduler when an executor fails. @@ -193,9 +212,15 @@ class DAGScheduler( def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms => - bms.map(bm => TaskLocation(bm.host, bm.executorId)) + // Note: if the storage level is NONE, we don't need to get locations from block manager. 
+ val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + Seq.fill(rdd.partitions.size)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } } cacheLocs(rdd.id) = locs } @@ -208,19 +233,17 @@ class DAGScheduler( /** * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) */ private def getShuffleMapStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, jobId) + registerShuffleDependencies(shuffleDep, firstJobId) // Then register current shuffleDep - val stage = newOrUsedShuffleStage(shuffleDep, jobId) + val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage @@ -230,15 +253,15 @@ class DAGScheduler( /** * Helper function to eliminate some code re-use when creating new stages. */ - private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { - val parentStages = getParentStages(rdd, jobId) + private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, firstJobId) val id = nextStageId.getAndIncrement() (parentStages, id) } /** * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in - * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * newOrUsedShuffleStage. The stage will be associated with the provided firstJobId. * Production of shuffle map stages should always use newOrUsedShuffleStage, not * newShuffleMapStage directly. */ @@ -246,21 +269,19 @@ class DAGScheduler( rdd: RDD[_], numTasks: Int, shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, + firstJobId: Int, callSite: CallSite): ShuffleMapStage = { - val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId) val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, - jobId, callSite, shuffleDep) + firstJobId, callSite, shuffleDep) stageIdToStage(id) = stage - updateJobIdStageIdMaps(jobId, stage) + updateJobIdStageIdMaps(firstJobId, stage) stage } /** - * Create a ResultStage -- either directly for use as a result stage, or as part of the - * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated - * with the provided jobId. + * Create a ResultStage associated with the provided jobId. */ private def newResultStage( rdd: RDD[_], @@ -277,16 +298,16 @@ class DAGScheduler( /** * Create a shuffle map Stage for the given RDD. The stage will also be associated with the - * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * provided firstJobId. 
If a stage for the shuffleId existed previously so that the shuffleId is * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd val numTasks = rdd.partitions.size - val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) @@ -304,10 +325,10 @@ class DAGScheduler( } /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. + * Get or create the list of parent stages for a given RDD. The new Stages will be created with + * the provided firstJobId. */ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError @@ -321,7 +342,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - parents += getShuffleMapStage(shufDep, jobId) + parents += getShuffleMapStage(shufDep, firstJobId) case _ => waitingForVisit.push(dep.rdd) } @@ -336,11 +357,11 @@ class DAGScheduler( } /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, jobId) + val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } @@ -386,11 +407,12 @@ class DAGScheduler( def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { missing += mapStage } @@ -577,7 +599,7 @@ class DAGScheduler( private[scheduler] def doCancelAllJobs() { // Cancel all running jobs. - runningStages.map(_.jobId).foreach(handleJobCancellation(_, + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, reason = "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... 
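Editor's note: the shuffle-locality constants introduced in the DAGScheduler hunk above (the map/reduce thresholds and REDUCER_PREF_LOCS_FRACTION) feed the preferred-location logic added to getPreferredLocsInternal further down. Below is a standalone sketch of the underlying idea only, not Spark's MapOutputTracker code: a host counts as a preferred location for a reduce task when it holds at least a given fraction of that reducer's total map output. Host names and byte counts are invented.

```scala
// A host becomes a preferred location only if it holds at least `fraction`
// of the total map output destined for this reducer.
object ReduceLocalitySketch extends App {
  val fraction = 0.2 // mirrors the 0.2 REDUCER_PREF_LOCS_FRACTION above

  def preferredHosts(bytesByHost: Map[String, Long], fraction: Double): Seq[String] = {
    val total = bytesByHost.values.sum.toDouble
    if (total == 0) Seq.empty
    else bytesByHost.collect { case (host, bytes) if bytes / total >= fraction => host }.toSeq
  }

  val outputs = Map("host-a" -> 700L, "host-b" -> 250L, "host-c" -> 50L)
  println(preferredHosts(outputs, fraction)) // e.g. List(host-a, host-b)
}
```

Raising the fraction concentrates reduce tasks on fewer, data-rich hosts at the cost of possible scheduling delay, which is exactly the trade-off the comment above describes.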
@@ -603,7 +625,7 @@ class DAGScheduler( clearCacheLocs() val failedStagesCopy = failedStages.toArray failedStages.clear() - for (stage <- failedStagesCopy.sortBy(_.jobId)) { + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -623,7 +645,7 @@ class DAGScheduler( logTrace("failed: " + failedStages) val waitingStagesCopy = waitingStages.toArray waitingStages.clear() - for (stage <- waitingStagesCopy.sortBy(_.jobId)) { + for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -843,14 +865,14 @@ class DAGScheduler( } } - val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + stage.makeNewStageAttempt(partitionsToCompute.size) outputCommitCoordinator.stageStart(stage.id) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) @@ -886,30 +908,37 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = stage match { - case stage: ShuffleMapStage => - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } + val tasks: Seq[Task[_]] = try { + stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } - case stage: ResultStage => - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return } if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) - taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) + taskScheduler.submitTasks(new TaskSet( + tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1323,7 +1352,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back @@ -1364,10 +1393,10 @@ class DAGScheduler( private def 
getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 - if (!visited.add((rdd,partition))) { + if (!visited.add((rdd, partition))) { // Nil has already been returned for previously visited partitions. return Nil } @@ -1381,17 +1410,32 @@ class DAGScheduler( if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. + rdd.dependencies.foreach { case n: NarrowDependency[_] => + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } + case s: ShuffleDependency[_, _, _] => + // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION + // of data as preferred locations + if (shuffleLocalityEnabled && + rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && + s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => } Nil @@ -1404,17 +1448,29 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread at the end of the constructor + // Start the event thread and register the metrics source at the end of the constructor + env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + /** * The main event loop of the DAG scheduler. 
*/ - override def onReceive(event: DAGSchedulerEvent): Unit = event match { + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 12668b6c0988e..6b667d5d7645b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry, Timer} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) +private[scheduler] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "DAGScheduler" @@ -46,4 +45,8 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) + + /** Timer that tracks the time to process messages in the DAGScheduler's event loop */ + val messageProcessingTimer: Timer = + metricRegistry.timer(MetricRegistry.name("messageProcessingTime")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 529a5b2bf1a0d..5a06ef02f5c57 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -140,7 +140,9 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. 
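Editor's note: the event loop above now wraps each message in the Codahale Timer registered on DAGSchedulerSource, stopping the timer context in a finally block so failed messages are still recorded. A minimal sketch of that metrics pattern, assuming metrics-core is on the classpath; handle() is a stand-in and not the real doOnReceive.

```scala
import com.codahale.metrics.{MetricRegistry, Timer}

// Time each handled message and stop the context in `finally` so exceptions
// still contribute to the timer's count and latency distribution.
object TimedHandlerSketch extends App {
  val registry = new MetricRegistry()
  val timer: Timer = registry.timer(MetricRegistry.name("messageProcessingTime"))

  def handle(event: String): Unit = {
    val ctx = timer.time()
    try {
      // ... process the event here ...
      Thread.sleep(10)
    } finally {
      ctx.stop()
    }
  }

  (1 to 5).foreach(i => handle(s"event-$i"))
  println(s"count=${timer.getCount}, mean(ns)=${timer.getSnapshot.getMean}")
}
```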
*/ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) + // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) + // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) @@ -197,6 +199,9 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + // No-op because logging every update would be overkill + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index e55b76c36cc5f..f96eb8ca0ae00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener val date = new Date(System.currentTimeMillis()) writeInfo = dateFormat.get.format(date) + ": " + info } + // scalastyle:off println jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) + // scalastyle:on println } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 0b1d47cff3746..8321037cdc026 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -38,7 +38,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. */ -private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { +private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) extends Logging { // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None @@ -129,9 +129,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { } def stop(): Unit = synchronized { - coordinatorRef.foreach(_ send StopCoordinator) - coordinatorRef = None - authorizedCommittersByStage.clear() + if (isDriver) { + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None + authorizedCommittersByStage.clear() + } } // Marked private[scheduler] instead of private so this can be mocked in tests diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 86f357abb8723..c6d957b65f3fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. 
* @param sourceName Filename (or other source identifier) from whence @logData is being read - * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not */ def replay( @@ -62,7 +62,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { if (!maybeTruncated || lines.hasNext) { throw jpe } else { - logWarning(s"Got JsonParseException from log file $sourceName" + + logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index c0f3d5a13d623..bf81b9aca4810 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -28,9 +28,9 @@ private[spark] class ResultStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { // The active job for this result stage. Will be empty if the job has already finished // (e.g., because the job was cancelled). diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 646820520ea1b..8801a761afae3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -49,4 +49,11 @@ private[spark] trait SchedulerBackend { */ def applicationAttemptId(): Option[String] = None + /** + * Get the URLs for the driver logs. These URLs are used to display the links in the UI + * Executors tab for the driver. 
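Editor's note: getDriverLogUrls above is added to the SchedulerBackend trait with a default of None, so only cluster managers that can actually locate the driver's logs need to override it. A small sketch of that optional-capability pattern; the backend names and the URL are invented, not Spark's real classes.

```scala
// A trait method with a None default: backends opt in by overriding it.
object DriverLogUrlsSketch extends App {
  trait SchedulerBackendLike {
    def getDriverLogUrls: Option[Map[String, String]] = None
  }

  class LocalBackend extends SchedulerBackendLike

  class YarnishBackend extends SchedulerBackendLike {
    override def getDriverLogUrls: Option[Map[String, String]] =
      Some(Map("stdout" -> "http://nm:8042/node/containerlogs/c1/user/stdout"))
  }

  println(new LocalBackend().getDriverLogUrls)   // None
  println(new YarnishBackend().getDriverLogUrls) // Some(Map(stdout -> ...))
}
```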
+ * @return Map containing the log names and their respective URLs + */ + def getDriverLogUrls: Option[Map[String, String]] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 5e62c8468f007..864941d468af9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -56,7 +56,7 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var compare:Int = 0 + var compare: Int = 0 if (s1Needy && !s2Needy) { return true diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index d02210743484c..66c75f325fcde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -30,10 +30,10 @@ private[spark] class ShuffleMapStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite, val shuffleDep: ShuffleDependency[_, _, _]) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { override def toString: String = "ShuffleMapStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 169d4fd3a94f0..896f1743332f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -110,8 +113,13 @@ case class SparkListenerExecutorMetricsUpdate( extends SparkListenerEvent @DeveloperApi -case class SparkListenerApplicationStart(appName: String, appId: Option[String], - time: Long, sparkUser: String, appAttemptId: Option[String]) extends SparkListenerEvent +case class SparkListenerApplicationStart( + appName: String, + appId: Option[String], + time: Long, + sparkUser: String, + appAttemptId: Option[String], + driverLogs: Option[Map[String, String]] = None) extends SparkListenerEvent @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent @@ -210,6 +218,11 @@ trait SparkListener { * Called when the driver removes an executor. 
*/ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + + /** + * Called when the driver receives a block update info. + */ + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } } /** @@ -265,7 +278,7 @@ class StatsReportListener extends SparkListener with Logging { private[spark] object StatsReportListener extends Logging { // For profiling, the extremes are more interesting - val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) val probabilities = percentiles.map(_ / 100.0) val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" @@ -299,7 +312,7 @@ private[spark] object StatsReportListener extends Logging { dOpt.foreach { d => showDistribution(heading, d, formatNumber)} } - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -313,7 +326,7 @@ private[spark] object StatsReportListener extends Logging { } def showBytesDistribution( - heading:String, + heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long], taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 61e69ecc08387..04afde33f5aad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -58,6 +58,8 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case blockUpdated: SparkListenerBlockUpdated => + listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5d0ddb8377c33..b86724de2cb73 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.CallSite * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes * that each output partition is on. * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO + * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * @@ -51,7 +51,7 @@ private[spark] abstract class Stage( val rdd: RDD[_], val numTasks: Int, val parents: List[Stage], - val jobId: Int, + val firstJobId: Int, val callSite: CallSite) extends Logging { @@ -62,22 +62,28 @@ private[spark] abstract class Stage( var pendingTasks = new HashSet[Task[_]] + /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm - /** Pointer to the latest [StageInfo] object, set by DAGScheduler. 
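Editor's note: SparkListenerBlockUpdated above follows the usual recipe for adding a listener event: a new case class, a default no-op callback on the listener trait so existing listeners keep compiling, and one extra match case in the bus. A standalone sketch of that recipe with invented event and listener names.

```scala
// New events get a default no-op callback; the bus routes by pattern match.
object ListenerBusSketch extends App {
  sealed trait Event
  case class BlockUpdated(blockId: String) extends Event
  case class ExecutorRemoved(execId: String) extends Event

  trait Listener {
    def onBlockUpdated(e: BlockUpdated): Unit = {}       // default no-op
    def onExecutorRemoved(e: ExecutorRemoved): Unit = {} // default no-op
  }

  def post(listener: Listener, event: Event): Unit = event match {
    case e: BlockUpdated    => listener.onBlockUpdated(e)
    case e: ExecutorRemoved => listener.onExecutorRemoved(e)
  }

  val logging = new Listener {
    override def onBlockUpdated(e: BlockUpdated): Unit = println(s"block updated: ${e.blockId}")
  }
  post(logging, BlockUpdated("rdd_0_1"))   // prints
  post(logging, ExecutorRemoved("exec-2")) // ignored by the default no-op
}
```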
*/ - var latestInfo: StageInfo = StageInfo.fromStage(this) + /** + * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized + * here, before any attempts have actually been created, because the DAGScheduler uses this + * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts + * have been created). + */ + private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) - /** Return a new attempt id, starting with 0. */ - def newAttemptId(): Int = { - val id = nextAttemptId + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ + def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = { + _latestInfo = StageInfo.fromStage(this, nextAttemptId, Some(numPartitionsToCompute)) nextAttemptId += 1 - id } - def attemptId: Int = nextAttemptId + /** Returns the StageInfo for the most recent attempt for this stage. */ + def latestInfo: StageInfo = _latestInfo override final def hashCode(): Int = id override final def equals(other: Any): Boolean = other match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index e439d2a7e1229..5d2abbc67e9d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -70,12 +70,12 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. */ - def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = { + def fromStage(stage: Stage, attemptId: Int, numTasks: Option[Int] = None): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos new StageInfo( stage.id, - stage.attemptId, + attemptId, stage.name, numTasks.getOrElse(stage.numTasks), rddInfos, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 586d1e06204c1..15101c64f0503 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -125,7 +125,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 1f114a0207f7b..8b2a742b96988 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -40,6 +40,9 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long var metrics: TaskMetrics) extends TaskResult[T] with Externalizable { + private var valueObjectDeserialized = false + private var valueObject: T = _ + def this() = this(null.asInstanceOf[ByteBuffer], null, null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { @@ -72,10 +75,26 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long } } metrics = in.readObject().asInstanceOf[TaskMetrics] + valueObjectDeserialized = false } + /** + * When `value()` is called at the first time, it needs to deserialize `valueObject` from 
+ * `valueBytes`. It may cost dozens of seconds for a large instance. So when calling `value` at + * the first time, the caller should avoid to block other threads. + * + * After the first time, `value()` is trivial and just returns the deserialized `valueObject`. + */ def value(): T = { - val resultSer = SparkEnv.get.serializer.newInstance() - resultSer.deserialize(valueBytes) + if (valueObjectDeserialized) { + valueObject + } else { + // This should not run when holding a lock because it may cost dozens of seconds for a large + // value. + val resultSer = SparkEnv.get.serializer.newInstance() + valueObject = resultSer.deserialize(valueBytes) + valueObjectDeserialized = true + valueObject + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 391827c1d2156..46a6f6537e2ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -54,6 +54,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { return } + // deserialize "value" without holding any lock so that it won't block other threads. + // We should call it here, so that when it's called again in + // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. + directResult.value() (directResult, serializedData.limit()) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index b4b8a630694bb..ed3dde0fc3055 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{TimerTask, Timer} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -64,6 +64,9 @@ private[spark] class TaskSchedulerImpl( // How often to check for speculative tasks val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") + private val speculationScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation") + // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") @@ -142,10 +145,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, - SPECULATION_INTERVAL_MS milliseconds) { - Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - 
}(sc.env.actorSystem.dispatcher) + speculationScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + checkSpeculatableTasks() + } + }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS) } } @@ -412,6 +416,7 @@ private[spark] class TaskSchedulerImpl( } override def stop() { + speculationScheduler.shutdown() if (backend != null) { backend.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 7dc325283d961..82455b0426a5d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @param sched the TaskSchedulerImpl associated with the TaskSetManager * @param taskSet the TaskSet to manage scheduling for - * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * @param maxTaskFailures if any particular task fails this number of times, the entire * task set will be aborted */ private[spark] class TaskSetManager( @@ -620,6 +620,12 @@ private[spark] class TaskSetManager( val index = info.index info.markSuccessful() removeRunningTask(tid) + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. sched.dagScheduler.taskEnded( tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { @@ -775,10 +781,10 @@ private[spark] class TaskSetManager( // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because dequeueTaskFromList will skip already-running tasks. 
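Editor's note: the comments above explain why the task result's value is deserialized once in TaskResultGetter, outside any scheduler lock, so that the later call made while holding the TaskSchedulerImpl lock is cheap. Below is a standalone sketch of that deserialize-once idea using plain Java serialization instead of Spark's pluggable serializer; all names are invented.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// The expensive decode happens on the first value() call; later calls return the cache.
class LazyResult[T](bytes: Array[Byte]) {
  private var deserialized = false
  private var cached: T = _

  def value(): T = {
    if (!deserialized) {
      val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
      cached = in.readObject().asInstanceOf[T]
      deserialized = true
    }
    cached
  }
}

object LazyResultExample extends App {
  val bytesOut = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bytesOut)
  oos.writeObject("a large task result")
  oos.close()

  val result = new LazyResult[String](bytesOut.toByteArray)
  result.value()          // first call pays the deserialization cost (do this outside any lock)
  println(result.value()) // subsequent calls just return the cached object
}
```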
for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, @@ -855,9 +861,9 @@ private[spark] class TaskSetManager( case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" case _ => null } - + if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait) } else { 0L } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 70364cea62a80..4be1eda2e9291 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -75,7 +75,8 @@ private[spark] object CoarseGrainedClusterMessages { case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) + case class AddWebUIFilter( + filterName: String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage // Messages exchanged between the driver and the cluster manager for executor allocation diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f107148f3b8c6..7c7f70d8a193b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + // If this DriverEndpoint is changed to support multiple threads, + // then this may need to be changed so that we don't share the serializer + // instance across threads + private val ser = SparkEnv.get.closureSerializer.newInstance() + override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[RpcAddress, String] @@ -79,7 +84,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") - + reviveThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReviveOffers)) @@ -98,7 +103,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case None => // Ignoring the update since we don't know about the executor. 
logWarning(s"Ignored task status update ($taskId state $state) " + - "from unknown executor $sender with ID $executorId") + s"from unknown executor with ID $executorId") } } @@ -163,7 +168,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on all executors - def makeOffers() { + private def makeOffers() { launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toSeq)) @@ -175,16 +180,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on just one executor - def makeOffers(executorId: String) { + private def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) launchTasks(scheduler.resourceOffers( Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) } // Launch tasks returned by a set of resource offers - def launchTasks(tasks: Seq[Seq[TaskDescription]]) { + private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - val ser = SparkEnv.get.closureSerializer.newInstance() val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ccf1dc5af6120..687ae9620460f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend( val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() waitForRegistration() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 2a3a5d925d06f..bc67abb5df446 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend( private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. 
@@ -149,7 +149,7 @@ private[spark] abstract class YarnSchedulerBackend( } } - override def onStop(): Unit ={ + override def onStop(): Unit = { askAmThreadPool.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index dc59545b43314..cbade131494bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,17 +18,21 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{Collections, List => JList} +import java.util.{List => JList, Collections} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.collect.HashBiMap import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.util.Utils /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -51,7 +55,7 @@ private[spark] class CoarseMesosSchedulerBackend( val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] @@ -59,12 +63,34 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocatorManager calls [[doRequesTotalExecutors]]. + */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + + private val pendingRemovedSlaveIds = new HashSet[String] + // private lock object protecting mutable state above. 
Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -81,7 +107,7 @@ private[spark] class CoarseMesosSchedulerBackend( startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -115,12 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(sc.env.actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) @@ -130,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend( command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + - s" --driver-url $driverUrl" + + s" --driver-url $driverURL" + s" --executor-id ${offer.getSlaveId.getValue}" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + @@ -139,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + - s" --driver-url $driverUrl" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -152,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend( command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { @@ -169,15 +201,19 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. 
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { + stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { - val slaveId = offer.getSlaveId.toString + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + meetsConstraints && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -191,42 +227,36 @@ private[spark] class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) + .addResources(createResource("mem", calculateTotalMemory(sc))) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder) } + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) + Collections.singleton(offer.getId), + Collections.singleton(task.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo("Mesos task " + taskId + " is now " + state) - synchronized { + stateLock.synchronized { if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId @@ -244,8 +274,9 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } + executorTerminated(d, slaveId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node - mesosDriver.reviveOffers() + d.reviveOffers() } } } @@ -264,18 +295,39 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, 
slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. + */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.contains(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo("Mesos slave lost: " + slaveId.getValue) + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -286,4 +338,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. + logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.contains(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. 
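The executor-kill bookkeeping added in `doKillExecutors` above pairs with the `<slaveId>/<taskId>` executor-id convention from `sparkExecutorId`. The following is a self-contained sketch of that mapping only, under simplified assumptions: a plain Scala `Map` stands in for the `taskIdToSlaveId` map (which exposes an `inverse()` view in the real code), and `planKills` is a hypothetical helper, not part of the backend.

```scala
// Minimal sketch of the kill-executor bookkeeping described above (hypothetical names).
object CoarseKillSketch {
  // Spark executor ids in this backend are "<slaveId>/<taskId>".
  def sparkExecutorId(slaveId: String, taskId: Int): String = s"$slaveId/$taskId"

  /** Returns the Mesos task ids to kill and the slave ids that become pending removal. */
  def planKills(
      executorIds: Seq[String],
      taskIdToSlaveId: Map[Int, String]): (Seq[Int], Set[String]) = {
    val slaveIdToTaskId = taskIdToSlaveId.map(_.swap)   // inverse view, like the BiMap's inverse()
    val hits = executorIds.flatMap { execId =>
      val slaveId = execId.split("/")(0)
      slaveIdToTaskId.get(slaveId).map(taskId => (taskId, slaveId))
    }
    (hits.map(_._1), hits.map(_._2).toSet)
  }

  def main(args: Array[String]): Unit = {
    val taskIdToSlaveId = Map(0 -> "slave-a", 1 -> "slave-b")
    val (tasksToKill, pendingRemoved) =
      planKills(Seq(sparkExecutorId("slave-b", 1)), taskIdToSlaveId)
    println(s"kill Mesos tasks $tasksToKill, now pending removal: $pendingRemoved")
    // prints: kill Mesos tasks List(1), now pending removal: Set(slave-b)
  }
}
```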
+ // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1067a7f1caf4c..d3a20f822176e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index db0a080b3b0c0..d72e2af456e15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,14 +23,14 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a @@ -59,6 +59,10 @@ private[spark] class MesosSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { @@ -71,8 +75,8 @@ private[spark] class MesosSchedulerBackend( val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -115,14 +119,14 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) .setScalar( Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) + .setValue(calculateTotalMemory(sc)).build()) .build() val executorInfo = MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -146,7 +150,7 @@ private[spark] class MesosSchedulerBackend( private def createExecArg(): Array[Byte] = { if (execArgs == null) { val props = new HashMap[String, String] - for ((key,value) <- sc.conf.getAll) { + for ((key, value) <- sc.conf.getAll) { props(key) = value } // Serialize the map as an array of (String, String) pairs @@ -191,13 +195,31 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. 
CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt @@ -223,15 +245,15 @@ private[spark] class MesosSchedulerBackend( val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) acceptedOffers .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) - } + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } + } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
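The accept/decline decision applied to each offer above boils down to a small predicate. The sketch below restates it in isolation; `isUsable` and `totalMemoryMB` are hypothetical names, the attribute-constraint check is collapsed to a boolean parameter, and the overhead constants mirror the ones defined later in `MesosSchedulerUtils.calculateTotalMemory`.

```scala
// Simplified restatement of the offer filter described above (not the actual backend code).
object OfferFilterSketch {
  private val MEMORY_OVERHEAD_FRACTION = 0.10
  private val MEMORY_OVERHEAD_MINIMUM = 384

  /** Executor memory plus container overhead, mirroring calculateTotalMemory. */
  def totalMemoryMB(executorMemoryMB: Int): Int =
    executorMemoryMB +
      math.max(MEMORY_OVERHEAD_FRACTION * executorMemoryMB, MEMORY_OVERHEAD_MINIMUM).toInt

  /** An offer is usable if it satisfies constraints, memory and CPU, or if the slave
    * already runs an executor and can still fit one more task. */
  def isUsable(
      meetsConstraints: Boolean,
      offerMemMB: Double,
      offerCpus: Double,
      slaveHasExecutor: Boolean,
      executorMemoryMB: Int,
      mesosExecutorCores: Double = 1.0,
      cpusPerTask: Int = 1): Boolean = {
    val meetsMemory = offerMemMB >= totalMemoryMB(executorMemoryMB)
    val meetsCpu = offerCpus >= (mesosExecutorCores + cpusPerTask)
    (meetsConstraints && meetsMemory && meetsCpu) ||
      (slaveHasExecutor && offerCpus >= cpusPerTask)
  }

  def main(args: Array[String]): Unit = {
    // 4 GiB executors need 4096 + max(409, 384) = 4505 MB from an offer.
    println(totalMemoryMB(4096))                                            // 4505
    println(isUsable(meetsConstraints = true, offerMemMB = 8192,
      offerCpus = 4, slaveHasExecutor = false, executorMemoryMB = 4096))    // true
  }
}
```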
@@ -251,8 +273,6 @@ private[spark] class MesosSchedulerBackend( d.declineOffer(o.getId) } - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 928c5cfed417a..e79c543a9de27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -37,14 +37,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .newBuilder() .setMode(Volume.Mode.RW) spec match { - case Array(container_path) => + case Array(container_path) => Some(vol.setContainerPath(container_path)) case Array(container_path, "rw") => Some(vol.setContainerPath(container_path)) case Array(container_path, "ro") => Some(vol.setContainerPath(container_path) .setMode(Volume.Mode.RO)) - case Array(host_path, container_path) => + case Array(host_path, container_path) => Some(vol.setContainerPath(container_path) .setHostPath(host_path)) case Array(host_path, container_path, "rw") => @@ -108,7 +108,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { image: String, volumes: Option[List[Volume]] = None, network: Option[ContainerInfo.DockerInfo.Network] = None, - portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None):Unit = { + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = { val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d11228f3d016a..925702e63afd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,14 +17,17 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util.List +import java.util.{List => JList} import java.util.concurrent.CountDownLatch import scala.collection.JavaConversions._ +import scala.util.control.NonFatal -import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} -import org.apache.mesos.{MesosSchedulerDriver, Scheduler} -import org.apache.spark.Logging +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.GeneratedMessage +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.util.Utils /** @@ -36,7 +39,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private final val registerLatch = new CountDownLatch(1) // Driver for talking to Mesos - protected var mesosDriver: MesosSchedulerDriver = null + protected var mesosDriver: SchedulerDriver = null /** * Starts the MesosSchedulerDriver with the provided information. 
This method returns @@ -86,10 +89,150 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Get the amount of resources for the specified type from the resource list */ - protected def getResource(res: List[Resource], name: String): Double = { + protected def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } 0.0 } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * The attribute values are the mesos attribute types and they are + * @param offerAttributes + * @return + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + + /** + * Match the requirements (if any) to the offer attributes. + * if attribute requirements are not specified - return true + * else if attribute is defined and no values are given, simple attribute presence is performed + * else if attribute name and value is specified, subset match is performed on slave attributes + */ + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if provided values is less than equal to the offered values + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that is between the ranges specified + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values is a subset of offered set + requiredValues.subsetOf(offeredValue.getItemList.toSet) + case Some(textValue: Value.Text) => + // check if the specified value is equal, if multiple values are specified + // we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } + } + } + + /** + * Parses the attributes constraints provided to spark and build a matching data struct: + * Map[, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'. 
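Before the full scaladoc and Guava-based implementation that follow, the constraint-string format just described can be illustrated with a throwaway parser. This sketch uses plain `String.split` rather than the real `parseConstraintString`; an empty value is treated as an attribute-presence constraint, as the surrounding comments describe.

```scala
// Standalone sketch of the constraint-string format described above (illustrative only).
object ConstraintStringSketch {
  def parse(constraints: String): Map[String, Set[String]] =
    if (constraints.isEmpty) Map.empty
    else constraints.split(";").map { pair =>
      pair.split(":", 2) match {
        case Array(key)         => key.trim -> Set.empty[String]   // presence-only constraint
        case Array(key, "")     => key.trim -> Set.empty[String]
        case Array(key, values) => key.trim -> values.split(",").map(_.trim).toSet
      }
    }.toMap

  def main(args: Array[String]): Unit = {
    println(parse("tachyon:true;zone:us-east-1a,us-east-1b"))
    // Map(tachyon -> Set(true), zone -> Set(us-east-1a, us-east-1b))
    println(parse("rack"))   // Map(rack -> Set()): matches on attribute presence alone
  }
}
```

The string itself comes from the `spark.mesos.constraints` configuration read earlier in this patch's scheduler backend changes.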
The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone": -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match resources offers. + */ + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { + case (k, v) => + if (v == null || v.isEmpty) { + (k, Set[String]()) + } else { + (k, v.split(',').toSet) + } + } + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } + } + } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index e64d06c4d3cfc..4d48fcfea44e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -17,15 +17,16 @@ package org.apache.spark.scheduler.local +import java.io.File +import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent.TimeUnit import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() @@ -42,21 +43,19 @@ private case class StopExecutor() */ private[spark] class LocalEndpoint( override val rpcEnv: RpcEnv, + userClassPath: Seq[URL], scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) extends ThreadSafeRpcEndpoint with Logging { - private val 
reviveThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("local-revive-thread") - private var freeCores = totalCores - private val localExecutorId = SparkContext.DRIVER_IDENTIFIER - private val localExecutorHostname = "localhost" + val localExecutorId = SparkContext.DRIVER_IDENTIFIER + val localExecutorHostname = "localhost" private val executor = new Executor( - localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) + localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => @@ -79,27 +78,13 @@ private[spark] class LocalEndpoint( context.reply(true) } - def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) - val tasks = scheduler.resourceOffers(offers).flatten - for (task <- tasks) { + for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, task.name, task.serializedTask) } - if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { - // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - reviveThread.schedule(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ReviveOffers)) - } - }, 1000, TimeUnit.MILLISECONDS) - } - } - - override def onStop(): Unit = { - reviveThread.shutdownNow() } } @@ -115,11 +100,28 @@ private[spark] class LocalBackend( extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localEndpoint: RpcEndpointRef = null + private var localEndpoint: RpcEndpointRef = null + private val userClassPath = getUserClasspath(conf) + private val listenerBus = scheduler.sc.listenerBus + + /** + * Returns a list of URLs representing the user classpath. + * + * @param conf Spark configuration. 
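The method documented here (its body follows below) simply splits `spark.executor.extraClassPath` on the platform path separator and converts each entry to a `URL`. A dependency-free sketch of that transformation, with an `Option[String]` standing in for the `SparkConf` lookup:

```scala
import java.io.File
import java.net.URL

// Sketch of building the user classpath from a configured path string (hypothetical helper).
object UserClasspathSketch {
  def userClasspath(extraClassPath: Option[String]): Seq[URL] =
    extraClassPath
      .map(_.split(File.pathSeparator).toSeq)
      .getOrElse(Seq.empty)
      .map(entry => new File(entry).toURI.toURL)

  def main(args: Array[String]): Unit = {
    // On Unix-like systems File.pathSeparator is ":".
    println(userClasspath(Some("/opt/libs/a.jar:/opt/libs/b.jar")))
    // e.g. List(file:/opt/libs/a.jar, file:/opt/libs/b.jar)
    println(userClasspath(None))   // List()
  }
}
```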
+ */ + def getUserClasspath(conf: SparkConf): Seq[URL] = { + val userClassPathStr = conf.getOption("spark.executor.extraClassPath") + userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) + } override def start() { - localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( - "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) + val rpcEnv = SparkEnv.get.rpcEnv + val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) + localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint) + listenerBus.post(SparkListenerExecutorAdded( + System.currentTimeMillis, + executorEndpoint.localExecutorId, + new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index dfbde7c8a1b0d..4a5274b46b7a0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa extends DeserializationStream { private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] @@ -121,6 +124,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) + protected def this() = this(new SparkConf()) // For deserialization only + override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 64ba27f34d2f1..7cb6e080533ad 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,8 +17,9 @@ package org.apache.spark.serializer -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer +import javax.annotation.Nullable import scala.reflect.ClassTag @@ -35,7 +36,7 @@ import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -51,7 +52,7 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") - + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new 
IllegalArgumentException("spark.kryoserializer.buffer must be less than " + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") @@ -93,12 +94,15 @@ class KryoSerializer(conf: SparkConf) // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) - // Allow sending SerializableWritable + // Allow sending classes with custom Java serializers kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) + kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) try { + // scalastyle:off classforname // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. @@ -108,6 +112,7 @@ class KryoSerializer(conf: SparkConf) userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } + // scalastyle:on classforname } catch { case e: Exception => throw new SparkException(s"Failed to register classes with Kryo", e) @@ -136,21 +141,45 @@ class KryoSerializer(conf: SparkConf) } private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) +class KryoSerializationStream( + serInstance: KryoSerializerInstance, + outStream: OutputStream) extends SerializationStream { + + private[this] var output: KryoOutput = new KryoOutput(outStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - override def flush() { output.flush() } - override def close() { output.close() } + override def flush() { + if (output == null) { + throw new IOException("Stream is closed") + } + output.flush() + } + + override def close() { + if (output != null) { + try { + output.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + output = null + } + } + } } private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - private val input = new KryoInput(inStream) +class KryoDeserializationStream( + serInstance: KryoSerializerInstance, + inStream: InputStream) extends DeserializationStream { + + private[this] var input: KryoInput = new KryoInput(inStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def readObject[T: ClassTag](): T = { try { @@ -163,50 +192,105 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } override def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() + if (input != null) { + try { + // Kryo's Input automatically closes the input stream it is using. + input.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + input = null + } + } } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - private val kryo = ks.newKryo() - // Make these lazy vals to avoid creating a buffer unless we use them + /** + * A re-used [[Kryo]] instance. 
Methods will borrow this instance by calling `borrowKryo()`, do + * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching + * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are + * not synchronized. + */ + @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + + /** + * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; + * otherwise, it allocates a new instance. + */ + private[serializer] def borrowKryo(): Kryo = { + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have been modified + // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } + } + + /** + * Release a borrowed [[Kryo]] instance. If this serializer instance already has a cached Kryo + * instance, then the given Kryo instance is discarded; otherwise, the Kryo is stored for later + * re-use. + */ + private[serializer] def releaseKryo(kryo: Kryo): Unit = { + if (cachedKryo == null) { + cachedKryo = kryo + } + } + + // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() private lazy val input = new KryoInput() override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() + val kryo = borrowKryo() try { kryo.writeClassAndObject(output, t) } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + "increase spark.kryoserializer.buffer.max value.") + } finally { + releaseKryo(kryo) } ByteBuffer.wrap(output.toBytes) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] + val kryo = borrowKryo() + try { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + releaseKryo(kryo) + } } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + val kryo = borrowKryo() val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj + try { + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + kryo.setClassLoader(oldClassLoader) + releaseKryo(kryo) + } } override def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(kryo, s) + new KryoSerializationStream(this, s) } override def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) + new KryoDeserializationStream(this, s) } /** @@ -216,7 +300,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ def getAutoReset(): Boolean = { val field = classOf[Kryo].getDeclaredField("autoReset") field.setAccessible(true) - field.get(kryo).asInstanceOf[Boolean] + val kryo = borrowKryo() + try { + field.get(kryo).asInstanceOf[Boolean] + } finally { + releaseKryo(kryo) + } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index 5abfa467c0ec8..a1b1e1631eafb 100644 --- 
a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.io._ import java.lang.reflect.{Field, Method} import java.security.AccessController @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark.Logging -private[serializer] object SerializationDebugger extends Logging { +private[spark] object SerializationDebugger extends Logging { /** * Improve the given NotSerializableException with the serialization path leading from the given @@ -62,7 +62,7 @@ private[serializer] object SerializationDebugger extends Logging { * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ - def find(obj: Any): List[String] = { + private[serializer] def find(obj: Any): List[String] = { new SerializationDebugger().visit(obj, List.empty) } @@ -125,6 +125,12 @@ private[serializer] object SerializationDebugger extends Logging { return List.empty } + /** + * Visit an externalizable object. + * Since writeExternal() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutput that collects all the relevant objects for further testing. + */ private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = { val fieldList = new ListObjectOutput @@ -145,17 +151,50 @@ private[serializer] object SerializationDebugger extends Logging { // An object contains multiple slots in serialization. // Get the slots and visit fields in all of them. val (finalObj, desc) = findObjectAndDescriptor(o) + + // If the object has been replaced using writeReplace(), + // then call visit() on it again to test its type again. + if (!finalObj.eq(o)) { + return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack) + } + + // Every class is associated with one or more "slots", each slot refers to the parent + // classes of this class. These slots are used by the ObjectOutputStream + // serialization code to recursively serialize the fields of an object and + // its parent classes. For example, if there are the following classes. + // + // class ParentClass(parentField: Int) + // class ChildClass(childField: Int) extends ParentClass(1) + // + // Then serializing the an object Obj of type ChildClass requires first serializing the fields + // of ParentClass (that is, parentField), and then serializing the fields of ChildClass + // (that is, childField). Correspondingly, there will be two slots related to this object: + // + // 1. ParentClass slot, which will be used to serialize parentField of Obj + // 2. ChildClass slot, which will be used to serialize childField fields of Obj + // + // The following code uses the description of each slot to find the fields in the + // corresponding object to visit. + // val slotDescs = desc.getSlotDescs var i = 0 while (i < slotDescs.length) { val slotDesc = slotDescs(i) if (slotDesc.hasWriteObjectMethod) { - // TODO: Handle classes that specify writeObject method. + // If the class type corresponding to current slot has writeObject() defined, + // then its not obvious which fields of the class will be serialized as the writeObject() + // can choose arbitrary fields for serialization. This case is handled separately. 
+ val elem = s"writeObject data (class: ${slotDesc.getName})" + val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack) + if (childStack.nonEmpty) { + return childStack + } } else { + // Visit all the fields objects of the class corresponding to the current slot. val fields: Array[ObjectStreamField] = slotDesc.getFields val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) val numPrims = fields.length - objFieldValues.length - desc.getObjFieldValues(finalObj, objFieldValues) + slotDesc.getObjFieldValues(finalObj, objFieldValues) var j = 0 while (j < objFieldValues.length) { @@ -169,18 +208,54 @@ private[serializer] object SerializationDebugger extends Logging { } j += 1 } - } i += 1 } return List.empty } + + /** + * Visit a serializable object which has the writeObject() defined. + * Since writeObject() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutputStream that collects all the relevant fields for further testing. + * This is similar to how externalizable objects are visited. + */ + private def visitSerializableWithWriteObjectMethod( + o: Object, stack: List[String]): List[String] = { + val innerObjectsCatcher = new ListObjectOutputStream + var notSerializableFound = false + try { + innerObjectsCatcher.writeObject(o) + } catch { + case io: IOException => + notSerializableFound = true + } + + // If something was not serializable, then visit the captured objects. + // Otherwise, all the captured objects are safely serializable, so no need to visit them. + // As an optimization, just added them to the visited list. + if (notSerializableFound) { + val innerObjects = innerObjectsCatcher.outputArray + var k = 0 + while (k < innerObjects.length) { + val childStack = visit(innerObjects(k), stack) + if (childStack.nonEmpty) { + return childStack + } + k += 1 + } + } else { + visited ++= innerObjectsCatcher.outputArray + } + return List.empty + } } /** * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles * writeReplace in Serializable. It starts with the object itself, and keeps calling the - * writeReplace method until there is no more + * writeReplace method until there is no more. */ @tailrec private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { @@ -220,6 +295,31 @@ private[serializer] object SerializationDebugger extends Logging { override def writeByte(i: Int): Unit = {} } + /** An output stream that emulates /dev/null */ + private class NullOutputStream extends OutputStream { + override def write(b: Int) { } + } + + /** + * A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns + * them through `outputArray`. This works by using the [[ObjectOutputStream]]'s `replaceObject()` + * method which gets called on every object, only if replacing is enabled. So this subclass + * of [[ObjectOutputStream]] enabled replacing, and uses replaceObject to get the objects that + * are being serializabled. The serialized bytes are ignored by sending them to a + * [[NullOutputStream]], which acts like a /dev/null. 
+ */ + private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) { + private val output = new mutable.ArrayBuffer[Any] + this.enableReplaceObject(true) + + def outputArray: Array[Any] = output.toArray + + override def replaceObject(obj: Object): Object = { + output += obj + obj + } + } + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { def getSlotDescs: Array[ObjectStreamClass] = { @@ -307,7 +407,9 @@ private[serializer] object SerializationDebugger extends Logging { /** ObjectStreamClass$ClassDataSlot.desc field */ val DescField: Field = { + // scalastyle:off classforname val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + // scalastyle:on classforname f.setAccessible(true) f } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 6078c9d433ebf..bd2704dc81871 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag @@ -114,8 +115,12 @@ object Serializer { /** * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. + * + * It is legal to create multiple serialization / deserialization streams from the same + * SerializerInstance as long as those streams are all used within the same thread. */ @DeveloperApi +@NotThreadSafe abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer @@ -177,6 +182,7 @@ abstract class DeserializationStream { } catch { case eof: EOFException => finished = true + null } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6ad427bcac7f9..f6a96d81e7aa9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { - val writers: Array[BlockObjectWriter] + val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. 
*/ def releaseWriters(success: Boolean) @@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** @@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { + val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, writeMetrics) } } else { - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d9c63b6e7bbb9..fae69551e7330 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index f6e6fe5defe09..4cc4ef5f1886e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -17,14 +17,17 @@ package org.apache.spark.shuffle +import java.io.IOException + import org.apache.spark.scheduler.MapStatus /** * Obtained inside a map task to write out records to the shuffle system. 
*/ -private[spark] trait ShuffleWriter[K, V] { +private[spark] abstract class ShuffleWriter[K, V] { /** Write a sequence of records to this task's output */ - def write(records: Iterator[_ <: Product2[K, V]]): Unit + @throws[IOException] + def write(records: Iterator[Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 80374adc44296..9d8e7e9f03aea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -17,29 +17,29 @@ package org.apache.spark.shuffle.hash -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} +import java.io.InputStream + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.{Failure, Success} import org.apache.spark._ -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( + def fetchBlockStreams( shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) - : Iterator[T] = + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + blocksByAddress, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler + blockFetcherItr.map { blockPair => val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] + case Success(inputStream) => { + (blockId, inputStream) } case Failure(e) => { blockId match { @@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - 
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b9..d5c9880659dd3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,16 +17,20 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, @@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. 
+ val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 897f0a5dc5bcc..41df70c602c30 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter private[spark] class HashShuffleWriter[K, V]( shuffleBlockResolver: FileShuffleBlockResolver, @@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V]( writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { val iter = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { dep.aggregator.get.combineValuesByKey(records, context) @@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). 
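Stepping back from the writer changes for a moment, the reader-side pipeline assembled in `HashShuffleReader.read()` above is: fetch block streams, wrap them for compression, deserialize to key/value records, count each record for the shuffle-read metrics, then aggregate if the dependency asks for it. A stripped-down model with plain Scala iterators follows; the deserializer and combiner here are stand-ins, not Spark's.

```scala
// Plain-Scala model of the shuffle read pipeline described above (stand-in functions only).
object ShuffleReadPipelineSketch {
  def main(args: Array[String]): Unit = {
    // Pretend each fetched "block" deserializes to a small batch of (word, count) records.
    val blocks: Iterator[Seq[(String, Int)]] = Iterator(
      Seq("a" -> 1, "b" -> 2),
      Seq("a" -> 3, "c" -> 1))

    var recordsRead = 0L                                   // plays the role of shuffle read metrics

    val recordIter: Iterator[(String, Int)] =
      blocks.flatMap(_.iterator)                           // like deserializeStream(...).asKeyValueIterator
        .map { rec => recordsRead += 1; rec }              // metric update per record

    // "combineValuesByKey" stand-in: sum the values for each key.
    val combined = recordIter.foldLeft(Map.empty[String, Int]) {
      case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0) + v)
    }

    println(combined)        // Map(a -> 4, b -> 2, c -> 1)
    println(recordsRead)     // 4
  }
}
```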
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => writer.commitAndClose() writer.fileSegment().length } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 15842941daaab..d7fab351ca3b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager true } - override def shuffleBlockResolver: IndexShuffleBlockResolver = { + override val shuffleBlockResolver: IndexShuffleBlockResolver = { indexShuffleBlockResolver } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index add2656294ca2..5865e7640c1cf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: ExternalSorter[K, V, _] = null + private var sorter: SortShuffleFileWriter[K, V] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -48,19 +49,28 @@ private[spark] class SortShuffleWriter[K, V, C]( context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - if (dep.mapSideCombine) { + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - sorter = new ExternalSorter[K, V, C]( + new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.insertAll(records) + } else if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. 
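The branch taken here rests on the small predicate added further below as `SortShuffleWriter.shouldBypassMergeSort`. Restated with the default threshold of 200 written out (it is normally read from `spark.shuffle.sort.bypassMergeThreshold`):

```scala
// Restatement of the bypass-merge-sort condition (defaults written out for clarity).
object BypassDecisionSketch {
  def shouldBypassMergeSort(
      numPartitions: Int,
      hasAggregator: Boolean,
      hasKeyOrdering: Boolean,
      bypassMergeThreshold: Int = 200): Boolean =   // spark.shuffle.sort.bypassMergeThreshold
    numPartitions <= bypassMergeThreshold && !hasAggregator && !hasKeyOrdering

  def main(args: Array[String]): Unit = {
    println(shouldBypassMergeSort(numPartitions = 50,
      hasAggregator = false, hasKeyOrdering = false))   // true: write 50 files and concatenate
    println(shouldBypassMergeSort(numPartitions = 5000,
      hasAggregator = false, hasKeyOrdering = false))   // false: too many open files and buffers
  }
}
```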
+ new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, + writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) - sorter.insertAll(records) + new ExternalSorter[K, V, V]( + aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } + sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C]( } } +private[spark] object SortShuffleWriter { + def shouldBypassMergeSort( + conf: SparkConf, + numPartitions: Int, + aggregator: Option[Aggregator[_, _, _]], + keyOrdering: Option[Ordering[_]]): Boolean = { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala new file mode 100644 index 0000000000000..df7bbd64247dd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + */ +private[spark] class UnsafeShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +private[spark] object UnsafeShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use the optimized unsafe shuffle + * path or whether it should fall back to the original sort-based shuffle. 
+ */ + def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + val shufId = dependency.shuffleId + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + false + } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") + false + } else { + log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + true + } + } +} + +/** + * A shuffle implementation that uses directly-managed memory to implement several performance + * optimizations for certain types of shuffles. In cases where the new performance optimizations + * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those + * shuffles. + * + * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: + * + * - The shuffle dependency specifies no aggregation or output ordering. + * - The shuffle serializer supports relocation of serialized values (this is currently supported + * by KryoSerializer and Spark SQL's custom serializers). + * - The shuffle produces fewer than 16777216 output partitions. + * - No individual record is larger than 128 MB when serialized. + * + * In addition, extra spill-merging optimizations are automatically applied when the shuffle + * compression codec supports concatenation of serialized streams. This is currently supported by + * Spark's LZF serializer. + * + * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * UnsafeShuffleManager optimizes this process in several ways: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. 
+ * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on UnsafeShuffleManager's design, see SPARK-7081. + */ +private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + + "manager; its optimized shuffles will continue to spill to disk when necessary.") + } + + private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) + private[this] val shufflesThatFellBackToSortShuffle = + Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) + private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() + + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { + new UnsafeShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + sortShuffleManager.getReader(handle, startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + handle match { + case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) + val env = SparkEnv.get + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case other => + shufflesThatFellBackToSortShuffle.add(handle.shuffleId) + sortShuffleManager.getWriter(handle, mapId, context) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { + sortShuffleManager.unregisterShuffle(shuffleId) + } else { + Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + } + + override val shuffleBlockResolver: IndexShuffleBlockResolver = { + sortShuffleManager.shuffleBlockResolver + } + + /** Shut down this ShuffleManager. 
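// Illustrative sketch (not part of this patch): the bookkeeping pattern used by
// getWriter/unregisterShuffle above -- a concurrent set records shuffles that fell back to
// sort-based shuffle, and a concurrent map remembers how many map outputs the optimized
// path wrote so their data files can be cleaned up later. The same JDK primitives
// (Collections.newSetFromMap over a ConcurrentHashMap) are used here; the cleanup callback
// is a hypothetical stand-in for IndexShuffleBlockResolver.removeDataByMap.
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

object ShuffleFallbackBookkeepingSketch {
  private val fellBackToSortShuffle =
    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
  private val numMapsForOptimizedPath = new ConcurrentHashMap[Int, Int]()

  def recordWriter(shuffleId: Int, numMaps: Int, usedOptimizedPath: Boolean): Unit = {
    if (usedOptimizedPath) numMapsForOptimizedPath.putIfAbsent(shuffleId, numMaps)
    else fellBackToSortShuffle.add(shuffleId)
  }

  /** Clean up one shuffle according to which write path it used. */
  def unregister(shuffleId: Int)(removeDataByMap: (Int, Int) => Unit): Boolean = {
    if (fellBackToSortShuffle.remove(shuffleId)) {
      true // in the real code, cleanup is delegated to the sort-based manager here
    } else {
      Option(numMapsForOptimizedPath.remove(shuffleId)).foreach { numMaps =>
        (0 until numMaps).foreach(mapId => removeDataByMap(shuffleId, mapId))
      }
      true
    }
  }

  def main(args: Array[String]): Unit = {
    recordWriter(shuffleId = 0, numMaps = 4, usedOptimizedPath = true)
    unregister(0)((shuffleId, mapId) => println(s"removing data for shuffle $shuffleId map $mapId"))
  }
}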
*/ + override def stop(): Unit = { + sortShuffleManager.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 50608588f09ae..390c136df79b3 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -169,7 +169,7 @@ private[v1] object AllStagesResource { val outputMetrics: Option[OutputMetricDistributions] = new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw:InternalTaskMetrics): Option[InternalOutputMetrics] = { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = { raw.outputMetrics } def build: OutputMetricDistributions = new OutputMetricDistributions( @@ -284,7 +284,7 @@ private[v1] object AllStagesResource { * the options (returning None if the metrics are all empty), and extract the quantiles for each * metric. After creating an instance, call metricOption to get the result type. */ -private[v1] abstract class MetricHelper[I,O]( +private[v1] abstract class MetricHelper[I, O]( rawMetrics: Seq[InternalTaskMetrics], quantiles: Array[Double]) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala similarity index 86% rename from core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala rename to core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index c3ec45f54681b..50b6ba67e9931 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.zip.ZipOutputStream import javax.servlet.ServletContext import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -39,7 +40,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. 
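// Illustrative sketch (not part of this patch): the MetricHelper plumbing above turns
// per-task metrics into distributions at requested quantiles. The helper below shows the
// core step -- pick values at the requested probabilities from a sorted copy of the
// per-task values -- using a simple nearest-rank rule, which is an assumption for this
// sketch rather than the exact interpolation the REST API code uses.
object QuantileSketch {
  def quantiles(values: Seq[Double], probabilities: Array[Double]): IndexedSeq[Double] = {
    require(values.nonEmpty, "no metric values")
    val sorted = values.sorted.toIndexedSeq
    probabilities.toIndexedSeq.map { p =>
      val idx = math.min(sorted.size - 1, math.max(0, math.ceil(p * sorted.size).toInt - 1))
      sorted(idx)
    }
  }

  def main(args: Array[String]): Unit = {
    val taskRunTimesMs = Seq(12.0, 15.0, 9.0, 31.0, 22.0, 18.0, 14.0, 11.0)
    // The API exposes distributions at fixed probabilities such as these.
    println(quantiles(taskRunTimesMs, Array(0.05, 0.25, 0.5, 0.75, 0.95)))
  }
}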
*/ @Path("/v1") -private[v1] class JsonRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -101,7 +102,7 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { @Path("applications/{appId}/stages") - def getStages(@PathParam("appId") appId: String): AllStagesResource= { + def getStages(@PathParam("appId") appId: String): AllStagesResource = { uiRoot.withSparkUI(appId, None) { ui => new AllStagesResource(ui) } @@ -110,14 +111,14 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { @Path("applications/{appId}/{attemptId}/stages") def getStages( @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllStagesResource= { + @PathParam("attemptId") attemptId: String): AllStagesResource = { uiRoot.withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") - def getStage(@PathParam("appId") appId: String): OneStageResource= { + def getStage(@PathParam("appId") appId: String): OneStageResource = { uiRoot.withSparkUI(appId, None) { ui => new OneStageResource(ui) } @@ -164,14 +165,26 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { } } + @Path("applications/{appId}/logs") + def getEventLogs( + @PathParam("appId") appId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, None) + } + + @Path("applications/{appId}/{attemptId}/logs") + def getEventLogs( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } -private[spark] object JsonRootResource { +private[spark] object ApiRootResource { - def getJsonServlet(uiRoot: UIRoot): ServletContextHandler = { + def getServletHandler(uiRoot: UIRoot): ServletContextHandler = { val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) - jerseyContext.setContextPath("/json") - val holder:ServletHolder = new ServletHolder(classOf[ServletContainer]) + jerseyContext.setContextPath("/api") + val holder: ServletHolder = new ServletHolder(classOf[ServletContainer]) holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass", "com.sun.jersey.api.core.PackagesResourceConfig") holder.setInitParameter("com.sun.jersey.config.property.packages", @@ -193,6 +206,17 @@ private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + /** + * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is + * [[None]], event logs for all attempts of this application will be written out. + */ + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { + Response.serverError() + .entity("Event logs are only available through the history server.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + /** * Get the spark UI with the given appID, and apply a function * to it. 
If there is no such app, throw an appropriate exception diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala new file mode 100644 index 0000000000000..22e21f0c62a29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.io.OutputStream +import java.util.zip.ZipOutputStream +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.{MediaType, Response, StreamingOutput} + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil + +@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) +private[v1] class EventLogDownloadResource( + val uIRoot: UIRoot, + val appId: String, + val attemptId: Option[String]) extends Logging { + val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) + + @GET + def getEventLogs(): Response = { + try { + val fileName = { + attemptId match { + case Some(id) => s"eventLogs-$appId-$id.zip" + case None => s"eventLogs-$appId.zip" + } + } + + val stream = new StreamingOutput { + override def write(output: OutputStream): Unit = { + val zipStream = new ZipOutputStream(output) + try { + uIRoot.writeEventLogs(appId, attemptId, zipStream) + } finally { + zipStream.close() + } + + } + } + + Response.ok(stream) + .header("Content-Disposition", s"attachment; filename=$fileName") + .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) + .build() + } catch { + case NonFatal(e) => + Response.serverError() + .entity(s"Event logs are not available for app: $appId.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index 07b224fac4786..dfdc09c6caf3b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.SparkUI private[v1] class OneRDDResource(ui: SparkUI) { @GET - def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( throw new NotFoundException(s"no rdd found w/ id $rddId") ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index fd24aea63a8a1..f9812f06cf527 100644 --- 
a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -83,7 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { withStageAttempt(stageId, stageAttemptId) { stage => val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) + tasks.slice(offset, offset + length) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index cee29786c3019..0c71cd2382225 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -16,40 +16,33 @@ */ package org.apache.spark.status.api.v1 -import java.text.SimpleDateFormat +import java.text.{ParseException, SimpleDateFormat} import java.util.TimeZone import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status -import scala.util.Try - private[v1] class SimpleDateParam(val originalValue: String) { - val timestamp: Long = { - SimpleDateParam.formats.collectFirst { - case fmt if Try(fmt.parse(originalValue)).isSuccess => - fmt.parse(originalValue).getTime() - }.getOrElse( - throw new WebApplicationException( - Response - .status(Status.BAD_REQUEST) - .entity("Couldn't parse date: " + originalValue) - .build() - ) - ) - } -} -private[v1] object SimpleDateParam { - - val formats: Seq[SimpleDateFormat] = { - - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") - gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) - - Seq( - new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz"), - gmtDay - ) + val timestamp: Long = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + try { + format.parse(originalValue).getTime() + } catch { + case _: ParseException => + val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) + try { + gmtDay.parse(originalValue).getTime() + } catch { + case _: ParseException => + throw new WebApplicationException( + Response + .status(Status.BAD_REQUEST) + .entity("Couldn't parse date: " + originalValue) + .build() + ) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index ef3c8570d8186..2bec64f2ef02b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -134,7 +134,7 @@ class StageData private[spark]( val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], - val executorSummary:Option[Map[String,ExecutorStageSummary]]) + val executorSummary: Option[Map[String, ExecutorStageSummary]]) class TaskData private[spark]( val taskId: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index cc794e5c90ffa..86493673d958d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,12 +17,11 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} +import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} import 
scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ import scala.util.Random @@ -77,12 +76,17 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) + // Actual storage of where blocks are kept private var externalBlockStoreInitialized = false private[spark] val memoryStore = new MemoryStore(this, maxMemory) private[spark] val diskStore = new DiskStore(this, diskBlockManager) - private[spark] lazy val externalBlockStore: ExternalBlockStore = + private[spark] lazy val externalBlockStore: ExternalBlockStore = { + externalBlockStoreInitialized = true new ExternalBlockStore(this, executorId) + } private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -266,11 +270,13 @@ private[spark] class BlockManager( asyncReregisterLock.synchronized { if (asyncReregisterTask == null) { asyncReregisterTask = Future[Unit] { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool reregister() asyncReregisterLock.synchronized { asyncReregisterTask = null } - } + }(futureExecutionContext) } } } @@ -485,16 +491,17 @@ private[spark] class BlockManager( if (level.useOffHeap) { logDebug(s"Getting block $blockId from ExternalBlockStore") if (externalBlockStore.contains(blockId)) { - externalBlockStore.getBytes(blockId) match { - case Some(bytes) => - if (!asBlockResult) { - return Some(bytes) - } else { - return Some(new BlockResult( - dataDeserialize(blockId, bytes), DataReadMethod.Memory, info.size)) - } + val result = if (asBlockResult) { + externalBlockStore.getValues(blockId) + .map(new BlockResult(_, DataReadMethod.Memory, info.size)) + } else { + externalBlockStore.getBytes(blockId) + } + result match { + case Some(values) => + return result case None => - logDebug(s"Block $blockId not found in externalBlockStore") + logDebug(s"Block $blockId not found in ExternalBlockStore") } } } @@ -641,7 +648,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, @@ -744,7 +751,11 @@ private[spark] class BlockManager( case b: ByteBufferValues if putLevel.replication > 1 => // Duplicate doesn't copy the bytes, but just creates a wrapper val bufferView = b.buffer.duplicate() - Future { replicate(blockId, bufferView, putLevel) } + Future { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool + replicate(blockId, bufferView, putLevel) + }(futureExecutionContext) case _ => null } @@ -1198,8 +1209,19 @@ private[spark] class BlockManager( bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator + 
dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + } + + /** + * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserializeStream( + blockId: BlockId, + inputStream: InputStream, + serializer: Serializer = defaultSerializer): Iterator[Any] = { + val stream = new BufferedInputStream(inputStream) + serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator } def stop(): Unit = { @@ -1218,6 +1240,7 @@ private[spark] class BlockManager( } metadataCleaner.cancel() broadcastCleaner.cancel() + futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index a85e1c7632973..f70f701494dbf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,13 +17,14 @@ package org.apache.spark.storage +import scala.collection.Iterable +import scala.collection.generic.CanBuildFrom import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.RpcUtils +import org.apache.spark.util.{ThreadUtils, RpcUtils} private[spark] class BlockManagerMaster( @@ -32,7 +33,7 @@ class BlockManagerMaster( isDriver: Boolean) extends Logging { - val timeout = RpcUtils.askTimeout(conf) + val timeout = RpcUtils.askRpcTimeout(conf) /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { @@ -102,10 +103,10 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") - } + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -114,10 +115,10 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") - } + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -128,10 +129,10 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}") - } + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -169,11 +170,17 @@ class BlockManagerMaster( val response = driverEndpoint. 
askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip - val result = Await.result(Future.sequence(futures), timeout) - if (result == null) { + implicit val sameThread = ThreadUtils.sameThread + val cbf = + implicitly[ + CanBuildFrom[Iterable[Future[Option[BlockStatus]]], + Option[BlockStatus], + Iterable[Option[BlockStatus]]]] + val blockStatus = timeout.awaitResult( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread)) + if (blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } - val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => status.map { s => (blockManagerId, s) } }.toMap @@ -192,7 +199,15 @@ class BlockManagerMaster( askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) - Await.result(future, timeout) + timeout.awaitResult(future) + } + + /** + * Find out if the executor has cached blocks. This method does not consider broadcast blocks, + * since they are not reported the master. + */ + def hasCachedBlocks(executorId: String): Boolean = { + driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId)) } /** Stop the driver endpoint, called only on the Spark driver node */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 3afb4c3c02e2d..5dc0c537cbb62 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} +import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} @@ -59,10 +60,11 @@ class BlockManagerMasterEndpoint( register(blockManagerId, maxMemSize, slaveEndpoint) context.reply(true) - case UpdateBlockInfo( + case _updateBlockInfo @ UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => context.reply(updateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) case GetLocations(blockId) => context.reply(getLocations(blockId)) @@ -112,6 +114,17 @@ class BlockManagerMasterEndpoint( case BlockManagerHeartbeat(blockManagerId) => context.reply(heartbeatReceived(blockManagerId)) + case HasCachedBlocks(executorId) => + blockManagerIdByExecutor.get(executorId) match { + case Some(bm) => + if (blockManagerInfo.contains(bm)) { + val bmInfo = blockManagerInfo(bm) + context.reply(bmInfo.cachedBlocks.nonEmpty) + } else { + context.reply(false) + } + case None => context.reply(false) + } } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -292,16 +305,16 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " + logError("Got two different block manager registrations on same executor - " + s" will replace old one 
$oldId with new one $id") - removeExecutor(id.executorId) + removeExecutor(id.executorId) case None => } logInfo("Registering block manager %s with %s RAM, %s".format( id.hostPort, Utils.bytesToString(maxMemSize), id)) - + blockManagerIdByExecutor(id.executorId) = id - + blockManagerInfo(id) = new BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } @@ -418,6 +431,9 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] + // Cached blocks held by this BlockManager. This does not include broadcast blocks. + private val _cachedBlocks = new mutable.HashSet[BlockId] + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { @@ -451,27 +467,35 @@ private[spark] class BlockManagerInfo( * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ + var blockStatus: BlockStatus = null if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) + blockStatus = BlockStatus(storageLevel, memSize, 0, 0) + _blocks.put(blockId, blockStatus) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) + blockStatus = BlockStatus(storageLevel, 0, diskSize, 0) + _blocks.put(blockId, blockStatus) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } if (storageLevel.useOffHeap) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, externalBlockStoreSize)) + blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize) + _blocks.put(blockId, blockStatus) logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize))) } + if (!blockId.isBroadcast && blockStatus.isCached) { + _cachedBlocks += blockId + } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) + _cachedBlocks -= blockId if (blockStatus.storageLevel.useMemory) { logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), @@ -494,6 +518,7 @@ private[spark] class BlockManagerInfo( _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) } + _cachedBlocks -= blockId } def remainingMem: Long = _remainingMem @@ -502,6 +527,9 @@ private[spark] class BlockManagerInfo( def blocks: JHashMap[BlockId, BlockStatus] = _blocks + // This does not include broadcast blocks. 
+ def cachedBlocks: collection.Set[BlockId] = _cachedBlocks + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem def clear() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1683576067fe8..376e9eb48843d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,7 +42,6 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -108,4 +107,6 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerMaster case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + + case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 543df4e1350dd..7478ab0fc2f7a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -40,7 +40,7 @@ class BlockManagerSlaveEndpoint( private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 8569c6f3cbbc3..c5ba9af3e2658 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala new file mode 100644 index 0000000000000..2789e25b8d3ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +private[spark] case class BlockUIData( + blockId: BlockId, + location: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +/** + * The aggregated status of stream blocks in an executor + */ +private[spark] case class ExecutorStreamBlockStatus( + executorId: String, + location: String, + blocks: Seq[BlockUIData]) { + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum + + def numStreamBlocks: Int = blocks.size + +} + +private[spark] class BlockStatusListener extends SparkListener { + + private val blockManagers = + new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + val blockId = blockUpdated.blockUpdatedInfo.blockId + if (!blockId.isInstanceOf[StreamBlockId]) { + // Now we only monitor StreamBlocks + return + } + val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize + + synchronized { + // Drop the update info if the block manager is not registered + blockManagers.get(blockManagerId).foreach { blocksInBlockManager => + if (storageLevel.isValid) { + blocksInBlockManager.put(blockId, + BlockUIData( + blockId, + blockManagerId.hostPort, + storageLevel, + memSize, + diskSize, + externalBlockStoreSize) + ) + } else { + // If isValid is not true, it means we should drop the block. 
+ blocksInBlockManager -= blockId + } + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + synchronized { + blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { + blockManagers -= blockManagerRemoved.blockManagerId + } + + def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { + blockManagers.map { case (blockManagerId, blocks) => + ExecutorStreamBlockStatus( + blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) + }.toSeq + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala similarity index 52% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala rename to core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala index a4a3a66b8b229..a5790e4454a89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -15,22 +15,33 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.storage -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo /** - * Overrides our expression evaluation tests and reruns them after optimization has occured. This - * is to ensure that constant folding and other optimizations do not break anything. + * :: DeveloperApi :: + * Stores information about a block status in a block manager. */ -class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) - super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) +@DeveloperApi +case class BlockUpdatedInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +private[spark] object BlockUpdatedInfo { + + private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = { + BlockUpdatedInfo( + updateBlockInfo.blockManagerId, + updateBlockInfo.blockId, + updateBlockInfo.storageLevel, + updateBlockInfo.memSize, + updateBlockInfo.diskSize, + updateBlockInfo.externalBlockStoreSize) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2a4447705fa65..5f537692a16c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. 
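// Illustrative sketch (not part of this patch): a user-side listener built on the same
// block-update event that BlockStatusListener consumes above. It assumes SparkListener
// exposes the onBlockUpdated callback being overridden above and that the listener would be
// registered via SparkContext.addSparkListener; it simply logs stream-block updates.
import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}
import org.apache.spark.storage.StreamBlockId

class StreamBlockLogger extends SparkListener {
  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    val info = blockUpdated.blockUpdatedInfo
    info.blockId match {
      case id: StreamBlockId =>
        println(s"stream block $id on ${info.blockManagerId.hostPort}: " +
          s"level=${info.storageLevel} mem=${info.memSize} disk=${info.diskSize}")
      case _ => // ignore non-stream blocks, as BlockStatusListener does
    }
  }
}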
These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. + */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") + Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { @@ -139,8 +145,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook { () => - logDebug("Shutdown hook called") + Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } } @@ -151,7 +157,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon try { Utils.removeShutdownHook(shutdownHook) } catch { - case e: Exception => + case e: Exception => logError(s"Exception while removing shutdown hook.", e) } doStop() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala similarity index 70% rename from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala rename to core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 8bc4e205bc3c6..49d9154f95a5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -26,76 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils /** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to + * revert partial writes. * - * This interface does not support concurrent writes. Also, once the writer has - * been opened, it cannot be reopened again. - */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { - - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - - /** - * Flush the partial writes and commit them as a single atomic block. - */ - def commitAndClose(): Unit - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. This method will not throw, though it may be - * unsuccessful in truncating written data. - */ - def revertPartialWritesAndClose() - - /** - * Writes a key-value pair. - */ - def write(key: Any, value: Any) - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. - */ - def recordWritten() - - /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment -} - -/** - * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. 
*/ private[spark] class DiskBlockObjectWriter( - blockId: BlockId, + val blockId: BlockId, file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, - // These write metrics concurrently shared with other active BlockObjectWriter's who + // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. writeMetrics: ShuffleWriteMetrics) - extends BlockObjectWriter(blockId) - with Logging -{ - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ - private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - override def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int): Unit = { - callWithTiming(out.write(b, off, len)) - } - override def close(): Unit = out.close() - override def flush(): Unit = out.flush() - } + extends OutputStream + with Logging { /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -105,6 +54,7 @@ private[spark] class DiskBlockObjectWriter( private var objOut: SerializationStream = null private var initialized = false private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false /** * Cursors used to represent positions in the file. @@ -131,12 +81,12 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 - override def open(): BlockObjectWriter = { + def open(): DiskBlockObjectWriter = { if (hasBeenClosed) { throw new IllegalStateException("Writer already closed. Cannot be reopened.") } fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(fos) + ts = new TimeTrackingOutputStream(writeMetrics, fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializerInstance.serializeStream(bs) @@ -150,9 +100,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - callWithTiming { - fos.getFD.sync() - } + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) } } { objOut.close() @@ -168,29 +118,40 @@ private[spark] class DiskBlockObjectWriter( } } - override def isOpen: Boolean = objOut != null + def isOpen: Boolean = objOut != null - override def commitAndClose(): Unit = { + /** + * Flush the partial writes and commit them as a single atomic block. + */ + def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. objOut.flush() bs.flush() close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() } - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + commitAndCloseHasBeenCalled = true } - // Discard current writes. We do this by flushing the outstanding writes and then - // truncating the file to its initial position. 
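// Illustrative sketch (not part of this patch): the append / commit / revert-by-truncate
// idea behind DiskBlockObjectWriter, reduced to plain java.io. The writer remembers the
// position at which it started appending; on revert it truncates the file back to that
// position, so a failed write leaves no partial data behind. Error handling and the write
// metrics bookkeeping of the real class are omitted.
import java.io.{File, FileOutputStream}
import java.nio.charset.StandardCharsets

object RevertableAppendSketch {
  final class RevertableAppender(file: File) {
    private val fos = new FileOutputStream(file, true) // append mode
    private val channel = fos.getChannel
    private val initialPosition = file.length()

    def write(bytes: Array[Byte]): Unit = fos.write(bytes)

    // Commit: flush and report the (offset, length) segment that was appended.
    def commitAndClose(): (Long, Long) = {
      fos.flush(); fos.close()
      (initialPosition, file.length() - initialPosition)
    }

    // Revert: truncate everything appended since the writer was opened.
    def revertAndClose(): Unit = {
      fos.flush()
      channel.truncate(initialPosition)
      fos.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("segment", ".bin")
    val w1 = new RevertableAppender(f)
    w1.write("good record".getBytes(StandardCharsets.UTF_8))
    println(w1.commitAndClose())        // (0, 11)

    val w2 = new RevertableAppender(f)
    w2.write("partial garbage".getBytes(StandardCharsets.UTF_8))
    w2.revertAndClose()                 // file truncated back to the committed 11 bytes
    println(f.length())                 // 11
  }
}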
- override def revertPartialWritesAndClose() { - try { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + */ + def revertPartialWritesAndClose() { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. + try { if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -208,7 +169,10 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(key: Any, value: Any) { + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { if (!initialized) { open() } @@ -228,7 +192,10 @@ private[spark] class DiskBlockObjectWriter( bs.write(kvBytes, offs, len) } - override def recordWritten(): Unit = { + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + def recordWritten(): Unit = { numRecordsWritten += 1 writeMetrics.incShuffleRecordsWritten(1) @@ -237,7 +204,15 @@ private[spark] class DiskBlockObjectWriter( } } - override def fileSegment(): FileSegment = { + /** + * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. + */ + def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } new FileSegment(file, initialPosition, finalPosition - initialPosition) } @@ -251,12 +226,6 @@ private[spark] class DiskBlockObjectWriter( reportedPosition = pos } - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) - } - // For testing private[spark] override def flush() { objOut.flush() diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala index 8964762df6af3..f39325a12d244 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala @@ -32,6 +32,8 @@ import java.nio.ByteBuffer */ private[spark] abstract class ExternalBlockManager { + protected var blockManager: BlockManager = _ + override def toString: String = {"External Block Store"} /** @@ -41,7 +43,9 @@ private[spark] abstract class ExternalBlockManager { * * @throws java.io.IOException if there is any file system failure during the initialization. */ - def init(blockManager: BlockManager, executorId: String): Unit + def init(blockManager: BlockManager, executorId: String): Unit = { + this.blockManager = blockManager + } /** * Drop the block from underlying external block store, if it exists.. @@ -73,6 +77,11 @@ private[spark] abstract class ExternalBlockManager { */ def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit + def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val bytes = blockManager.dataSerialize(blockId, values) + putBytes(blockId, bytes) + } + /** * Retrieve the block bytes. 
* @return Some(ByteBuffer) if the block bytes is successfully retrieved @@ -82,6 +91,17 @@ private[spark] abstract class ExternalBlockManager { */ def getBytes(blockId: BlockId): Option[ByteBuffer] + /** + * Retrieve the block data. + * @return Some(Iterator[Any]) if the block data is successfully retrieved + * None if the block does not exist in the external block store. + * + * @throws java.io.IOException if there is any file system failure in getting the block. + */ + def getValues(blockId: BlockId): Option[Iterator[_]] = { + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + } + /** * Get the size of the block saved in the underlying external block store, * which is saved before by putBytes. diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index 0bf770306ae9b..db965d54bafd6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -18,9 +18,11 @@ package org.apache.spark.storage import java.nio.ByteBuffer + +import scala.util.control.NonFatal + import org.apache.spark.Logging import org.apache.spark.util.Utils -import scala.util.control.NonFatal /** @@ -40,7 +42,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.map(_.getSize(blockId)).getOrElse(0) } catch { case NonFatal(t) => - logError(s"error in getSize from $blockId", t) + logError(s"Error in getSize($blockId)", t) 0L } } @@ -54,7 +56,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - putIterator(blockId, values.toIterator, level, returnValues) + putIntoExternalBlockStore(blockId, values.toIterator, returnValues) } override def putIterator( @@ -62,42 +64,70 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: values: Iterator[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - logDebug(s"Attempting to write values for block $blockId") - val bytes = blockManager.dataSerialize(blockId, values) - putIntoExternalBlockStore(blockId, bytes, returnValues) + putIntoExternalBlockStore(blockId, values, returnValues) } private def putIntoExternalBlockStore( blockId: BlockId, - bytes: ByteBuffer, + values: Iterator[_], returnValues: Boolean): PutResult = { - // So that we do not modify the input offsets ! - // duplicate does not copy buffer, so inexpensive - val byteBuffer = bytes.duplicate() - byteBuffer.rewind() - logDebug(s"Attempting to put block $blockId into ExtBlk store") + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") // we should never hit here if externalBlockManager is None. Handle it anyway for safety. 
try { val startTime = System.currentTimeMillis if (externalBlockManager.isDefined) { - externalBlockManager.get.putBytes(blockId, bytes) + externalBlockManager.get.putValues(blockId, values) + val size = getSize(blockId) + val data = if (returnValues) { + Left(getValues(blockId).get) + } else { + null + } val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( - blockId, Utils.bytesToString(byteBuffer.limit), finishTime - startTime)) + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) + } else { + logError(s"Error in putValues($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } catch { + case NonFatal(t) => + logError(s"Error in putValues($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } - if (returnValues) { - PutResult(bytes.limit(), Right(bytes.duplicate())) + private def putIntoExternalBlockStore( + blockId: BlockId, + bytes: ByteBuffer, + returnValues: Boolean): PutResult = { + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") + // we should never hit here if externalBlockManager is None. Handle it anyway for safety. + try { + val startTime = System.currentTimeMillis + if (externalBlockManager.isDefined) { + val byteBuffer = bytes.duplicate() + byteBuffer.rewind() + externalBlockManager.get.putBytes(blockId, byteBuffer) + val size = bytes.limit() + val data = if (returnValues) { + Right(bytes) } else { - PutResult(bytes.limit(), null) + null } + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) } else { - logError(s"error in putBytes $blockId") - PutResult(bytes.limit(), null, Seq((blockId, BlockStatus.empty))) + logError(s"Error in putBytes($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) } } catch { case NonFatal(t) => - logError(s"error in putBytes $blockId", t) - PutResult(bytes.limit(), null, Seq((blockId, BlockStatus.empty))) + logError(s"Error in putBytes($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) } } @@ -107,13 +137,19 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.map(_.removeBlock(blockId)).getOrElse(true) } catch { case NonFatal(t) => - logError(s"error in removing $blockId", t) + logError(s"Error in removeBlock($blockId)", t) true } } override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + try { + externalBlockManager.flatMap(_.getValues(blockId)) + } catch { + case NonFatal(t) => + logError(s"Error in getValues($blockId)", t) + None + } } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { @@ -121,7 +157,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.flatMap(_.getBytes(blockId)) } catch { case NonFatal(t) => - logError(s"error in getBytes from $blockId", t) + logError(s"Error in getBytes($blockId)", t) None } } @@ -130,13 +166,13 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: try { val ret = externalBlockManager.map(_.blockExists(blockId)).getOrElse(false) if (!ret) { - logInfo(s"remove block $blockId") + logInfo(s"Remove block $blockId") 
blockManager.removeBlock(blockId, true) } ret } catch { case NonFatal(t) => - logError(s"error in getBytes from $blockId", t) + logError(s"Error in getBytes($blockId)", t) false } } @@ -156,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) try { - val instance = Class.forName(clsName) + val instance = Utils.classForName(clsName) .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 95e2d688d9b17..021a9facfb0b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,6 +24,8 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") override def toString: String = { "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9e..e49e39679e940 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,23 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. 
*/ private[spark] @@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[InputStream])] with Logging { import ShuffleBlockFetcherIterator._ @@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { case SuccessFetchResult(_, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { @@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + */ + override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { + val iteratorTry: Try[InputStream] = result match { case FailureFetchResult(_, e) => Failure(e) case SuccessFetchResult(blockId, _, buf) => // There is a chance that createInputStream can fail (e.g. fetching a local file that does // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. 
- Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - }) + Try(buf.createInputStream()).map { inputStream => + new BufferReleasingInputStream(inputStream, this) } } @@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator( } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index bdc6276e41915..b53c86e89a273 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -22,7 +22,10 @@ import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, Random} +import scala.util.control.NonFatal + import com.google.common.io.ByteStreams + import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} import tachyon.TachyonURI @@ -38,7 +41,6 @@ import org.apache.spark.util.Utils */ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Logging { - var blockManager: BlockManager =_ var rootDirs: String = _ var master: String = _ var client: tachyon.client.TachyonFS = _ @@ -52,7 +54,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def init(blockManager: BlockManager, executorId: String): Unit = { - this.blockManager = blockManager + super.init(blockManager, executorId) val storeDir = blockManager.conf.get(ExternalBlockStore.BASE_DIR, "/tmp_spark_tachyon") val appFolderName = blockManager.conf.get(ExternalBlockStore.FOLD_NAME) @@ -95,8 +97,29 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) - os.write(bytes.array()) - os.close() + try { + os.write(bytes.array()) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } + } + + override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val file = getFile(blockId) + val os = file.getOutStream(WriteType.TRY_CACHE) + try { + blockManager.dataSerializeStream(blockId, os, 
values) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put values of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { @@ -105,21 +128,31 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log return None } val is = file.getInStream(ReadType.CACHE) - assert (is != null) try { val size = file.length val bs = new Array[Byte](size.asInstanceOf[Int]) ByteStreams.readFully(is, bs) Some(ByteBuffer.wrap(bs)) } catch { - case ioe: IOException => - logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) + case NonFatal(e) => + logWarning(s"Failed to get bytes of block $blockId from Tachyon", e) None } finally { is.close() } } + override def getValues(blockId: BlockId): Option[Iterator[_]] = { + val file = getFile(blockId) + if (file == null || file.getLocationHosts().size() == 0) { + return None + } + val is = file.getInStream(ReadType.CACHE) + Option(is).map { is => + blockManager.dataDeserializeStream(blockId, is) + } + } + override def getSize(blockId: BlockId): Long = { getFile(blockId.name).length } @@ -184,7 +217,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log tachyonDir = client.getFile(path) } } catch { - case e: Exception => + case NonFatal(e) => logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) } } @@ -206,7 +239,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log Utils.deleteRecursively(tachyonDir, client) } } catch { - case e: Exception => + case NonFatal(e) => logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 06e616220c706..c8356467fab87 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -68,7 +68,9 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) + // scalastyle:on println } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") @@ -210,10 +212,16 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bfe4a180e8a6f..3788916cf39bb 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -19,7 +19,8 @@ package org.apache.spark.ui import java.util.Date -import 
org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, + UIRoot} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -64,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, @@ -136,7 +137,7 @@ private[spark] object SparkUI { jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, - startTime: Long): SparkUI = { + startTime: Long): SparkUI = { create(Some(sc), conf, listenerBus, securityManager, appName, jobProgressListener = Some(jobProgressListener), startTime = startTime) } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 063e2a1f8b18e..e2d25e36365fa 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -35,6 +35,10 @@ private[spark] object ToolTips { val OUTPUT = "Bytes and records written to Hadoop." + val STORAGE_MEMORY = + "Memory used / total available memory for storage of data " + + "like RDD partitions cached in memory. " + val SHUFFLE_WRITE = "Bytes and records written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ad16becde85dd..718aea7e1dc22 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. 
@@ -267,9 +267,17 @@ private[spark] object UIUtils extends Logging { fixedWidth: Boolean = false, id: Option[String] = None, headerClasses: Seq[String] = Seq.empty, - stripeRowsWithCss: Boolean = true): Seq[Node] = { + stripeRowsWithCss: Boolean = true, + sortable: Boolean = true): Seq[Node] = { - val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + val listingTableClass = { + val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + if (sortable) { + _tableClass + " sortable" + } else { + _tableClass + } + } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" @@ -309,7 +317,7 @@ private[spark] object UIUtils extends Logging { started: Int, completed: Int, failed: Int, - skipped:Int, + skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100) @@ -352,15 +360,17 @@ private[spark] object UIUtils extends Logging {
    -
    + + +
+First we import the necessary classes.
+
+{% highlight python %}
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.clustering import StreamingKMeans
+{% endhighlight %}
+
+Then we make an input stream of vectors for training, as well as a stream of labeled data
+points for testing. We assume a StreamingContext `ssc` has been created, see
+[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
+
+{% highlight python %}
+def parse(lp):
+    label = float(lp[lp.find('(') + 1: lp.find(',')])
+    vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(','))
+    return LabeledPoint(label, vec)
+
+trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
+testData = ssc.textFileStream("/testing/data/dir").map(parse)
+{% endhighlight %}
+
+We create a model with random clusters and specify the number of clusters to find.
+
+{% highlight python %}
+model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0)
+{% endhighlight %}
+
+Now we register the streams for training and testing and start the job, printing
+the predicted cluster assignments on new data points as they arrive.
+
+{% highlight python %}
+model.trainOn(trainingData)
+print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features))))
+
+ssc.start()
+ssc.awaitTermination()
+{% endhighlight %}
+
    + +
    As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point @@ -580,7 +690,3 @@ should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or id (e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change! - -
    - -
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7b397e30b2d90..eedc23424ad54 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -77,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -107,7 +107,8 @@ other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 -val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) +val lambda = 0.01 +val model = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha) {% endhighlight %} @@ -148,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -209,7 +210,7 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 4f2a2f71048f7..3aa040046fca5 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -31,7 +31,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseVector) and [`SparseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseVector). We recommend using the factory methods implemented in -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) to create local vectors. +[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) to create local vectors. {% highlight scala %} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -57,7 +57,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/java/org/apache/spark/mllib/linalg/DenseVector.html) and [`SparseVector`](api/java/org/apache/spark/mllib/linalg/SparseVector.html). We recommend using the factory methods implemented in -[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vector.html) to create local vectors. +[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vectors.html) to create local vectors. {% highlight java %} import org.apache.spark.mllib.linalg.Vector; @@ -84,7 +84,7 @@ and the following as sparse vectors: with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors) to create sparse vectors. {% highlight python %} import numpy as np @@ -226,7 +226,8 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") A local matrix has integer-typed row and column indices and double-typed values, stored on a single machine. MLlib supports dense matrices, whose entry values are stored in a single double array in -column major. 
For example, the following matrix `\[ \begin{pmatrix} +column-major order, and sparse matrices, whose non-zero entry values are stored in the Compressed Sparse +Column (CSC) format in column-major order. For example, the following dense matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ 5.0 & 6.0 @@ -238,28 +239,33 @@ is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the m
    The base class of local matrices is -[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one -implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). +[`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide two +implementations: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix), +and [`SparseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseMatrix). We recommend using the factory methods implemented -in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local -matrices. +in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight scala %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +val sm: Matrix = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 2, 1), Array(9, 6, 8)) {% endhighlight %}
    The base class of local matrices is -[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one -implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). +[`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide two +implementations: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html), +and [`SparseMatrix`](api/java/org/apache/spark/mllib/linalg/SparseMatrix.html). We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local -matrices. +matrices. Remember, local matrices in MLlib are stored in column-major order. {% highlight java %} import org.apache.spark.mllib.linalg.Matrix; @@ -267,6 +273,30 @@ import org.apache.spark.mllib.linalg.Matrices; // Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + +// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +Matrix sm = Matrices.sparse(3, 2, new int[] {0, 1, 3}, new int[] {0, 2, 1}, new double[] {9, 6, 8}); +{% endhighlight %} +
    + +
+
+The base class of local matrices is
+[`Matrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrix), and we provide two
+implementations: [`DenseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.DenseMatrix),
+and [`SparseMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseMatrix).
+We recommend using the factory methods implemented
+in [`Matrices`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) to create local
+matrices. Remember, local matrices in MLlib are stored in column-major order.
+
+{% highlight python %}
+from pyspark.mllib.linalg import Matrix, Matrices
+
+# Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0))
+dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6])
+
+# Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0))
+sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8])
{% endhighlight %}
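As an editorial illustration of the CSC encoding used in these examples (not part of the upstream page): `colPtrs` gives, for each column, the range of positions in `rowIndices`/`values` that belong to it, and expanding the matrix back to a dense column-major array recovers the original entries. A quick Scala check of the 3 x 2 sparse matrix above:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Matrices

// colPtrs = [0, 1, 3]: column 0 owns entry 0, column 1 owns entries 1 and 2.
// rowIndices/values pair each entry with its row: (0,0)=9.0, (2,1)=6.0, (1,1)=8.0.
val sm = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 2, 1), Array(9.0, 6.0, 8.0))

// The dense column-major expansion is ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)).
println(sm.toArray.mkString(", "))  // 9.0, 0.0, 0.0, 0.0, 8.0, 6.0
{% endhighlight %}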
    @@ -296,70 +326,6 @@ backed by an RDD of its entries. The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. In general the use of non-deterministic RDDs can lead to errors. -### BlockMatrix - -A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is -a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is -the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. -`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. -`BlockMatrix` also has a helper function `validate` which can be used to check whether the -`BlockMatrix` is set up properly. - -
    -
    - -A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. -`toBlockMatrix` creates blocks of size 1024 x 1024 by default. -Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} - -val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries -// Create a CoordinateMatrix from an RDD[MatrixEntry]. -val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) -// Transform the CoordinateMatrix to a BlockMatrix -val matA: BlockMatrix = coordMat.toBlockMatrix().cache() - -// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. -// Nothing happens if it is valid. -matA.validate() - -// Calculate A^T A. -val ata = matA.transpose.multiply(matA) -{% endhighlight %} -
    - -
    - -A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. -`toBlockMatrix` creates blocks of size 1024 x 1024 by default. -Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. - -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.distributed.BlockMatrix; -import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; -import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; - -JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries -// Create a CoordinateMatrix from a JavaRDD. -CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); -// Transform the CoordinateMatrix to a BlockMatrix -BlockMatrix matA = coordMat.toBlockMatrix().cache(); - -// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. -// Nothing happens if it is valid. -matA.validate(); - -// Calculate A^T A. -BlockMatrix ata = matA.transpose().multiply(matA); -{% endhighlight %} -
    -
    - ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD @@ -530,3 +496,67 @@ IndexedRowMatrix indexedRowMatrix = mat.toIndexedRowMatrix(); {% endhighlight %} + +### BlockMatrix + +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is +a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is +the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the +`BlockMatrix` is set up properly. + +
    +
    + +A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} + +val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries +// Create a CoordinateMatrix from an RDD[MatrixEntry]. +val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) +// Transform the CoordinateMatrix to a BlockMatrix +val matA: BlockMatrix = coordMat.toBlockMatrix().cache() + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate() + +// Calculate A^T A. +val ata = matA.transpose.multiply(matA) +{% endhighlight %} +
    + +
    + +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.distributed.BlockMatrix; +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; +import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; + +JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries +// Create a CoordinateMatrix from a JavaRDD. +CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); +// Transform the CoordinateMatrix to a BlockMatrix +BlockMatrix matA = coordMat.toBlockMatrix().cache(); + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate(); + +// Calculate A^T A. +BlockMatrix ata = matA.transpose().multiply(matA); +{% endhighlight %} +
    +
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index f723cd6b9dfab..a69e41e2a1936 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th import org.apache.spark._ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("text8").map(line => line.split(" ").toSeq) @@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -380,7 +384,7 @@ data2 = labels.zip(normalizer2.transform(features)) [Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. ### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. #### Model Fitting @@ -401,7 +405,7 @@ Note that the user can also construct a `ChiSqSelectorModel` by hand by providin #### Example -The following example shows the basic use of ChiSqSelector. +The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
    @@ -410,14 +414,16 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category val discretizedData = data.map { lp => - LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) } -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features val selector = new ChiSqSelector(50) // Create ChiSqSelector model (selecting features) val transformer = selector.fit(discretizedData) @@ -446,19 +452,20 @@ JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category JavaRDD discretizedData = points.map( new Function() { @Override public LabeledPoint call(LabeledPoint lp) { final double[] discretizedFeatures = new double[lp.features().size()]; for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = lp.features().apply(i) / 16; + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); } return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); } }); -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features ChiSqSelector selector = new ChiSqSelector(50); // Create ChiSqSelector model (selecting features) final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); @@ -505,7 +512,7 @@ v_N ### Example -This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value. +This example below demonstrates how to transform vectors using a transforming vector value.
    @@ -514,16 +521,67 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors -// Load and parse the data: -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +// Create some vector data; also works for sparse vectors +val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) val transformingVector = Vectors.dense(0.0, 1.0, 2.0) val transformer = new ElementwiseProduct(transformingVector) // Batch transform and per-row transform give the same results: -val transformedData = transformer.transform(parsedData) -val transformedData2 = parsedData.map(x => transformer.transform(x)) +val transformedData = transformer.transform(data) +val transformedData2 = data.map(x => transformer.transform(x)) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +// Create some vector data; also works for sparse vectors +JavaRDD data = sc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + +// Batch transform and per-row transform give the same results: +JavaRDD transformedData = transformer.transform(data); +JavaRDD transformedData2 = data.map( + new Function() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } +); + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import ElementwiseProduct + +# Load and parse the data +sc = SparkContext() +data = sc.textFile("data/mllib/kmeans_data.txt") +parsedData = data.map(lambda x: [float(t) for t in x.split(" ")]) + +# Create weight vector. +transformingVector = Vectors.dense([0.0, 1.0, 2.0]) +transformer = ElementwiseProduct(transformingVector) + +# Batch transform +transformedData = transformer.transform(parsedData) +# Single-row transform +transformedData2 = transformer.transform(parsedData.first()) {% endhighlight %}
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 9fd9be0dd01b1..bcc066a185526 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -39,11 +39,11 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters:
    -[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an -[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. {% highlight scala %} diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index f8e879496c135..d2d1cc93fe006 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -7,7 +7,19 @@ description: MLlib machine learning library overview for Spark SPARK_VERSION_SHO MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: +filtering, dimensionality reduction, as well as underlying optimization primitives. +Guides for individual algorithms are listed below. + +The API is divided into 2 parts: + +* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. +* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. + +We list major functionality from both below, with links to detailed guides. + +# MLlib types, algorithms and utilities + +This lists functionality included in `spark.mllib`, the main MLlib API. * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) @@ -39,6 +51,7 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) +* [PMML model export](mllib-pmml-model-export.html) MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, @@ -48,8 +61,8 @@ and the migration guide below will explain all changes between releases. Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. -It is currently an alpha component, and we would like to hear back from the community about -how it fits real-world use cases and how it could be improved. + +*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. Note that we will keep supporting and adding features to `spark.mllib` along with the development of `spark.ml`. @@ -57,7 +70,11 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. 
+More detailed guides for `spark.ml` include: + +* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts +* [Feature transformers](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API +* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API # Dependencies @@ -89,21 +106,14 @@ version 1.4 or newer. For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -## From 1.2 to 1.3 - -In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. - -* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. -* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. -* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: - * The constructor taking arguments was removed in favor of a builder patten using the default constructor plus parameter setter methods. - * Variable `model` is no longer public. -* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: - * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) - * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. -* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. -* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. - So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. 
It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. ## Previous Spark Versions diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index b521c2f27cd6e..5732bc4c7e79e 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f labels and real labels in the test set. {% highlight scala %} -import org.apache.spark.mllib.regression.IsotonicRegression +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") @@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point => // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() println("Mean Squared Error = " + meanSquaredError) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( ).rdd()).mean(); System.out.println("Mean Squared Error = " + meanSquaredError); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 2b2be4d9d0273..3927d65fbf8fb 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -163,11 +163,8 @@ object, and make predictions with the resulting model to compute the training error. {% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. @@ -231,15 +228,13 @@ calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} -import java.util.Random; - import scala.Tuple2; import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.*; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; + import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; @@ -282,8 +277,8 @@ public class SVMClassifier { System.out.println("Area under ROC = " + auROC); // Save and load model - model.save(sc.sc(), "myModelPath"); - SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath"); + model.save(sc, "myModelPath"); + SVMModel sameModel = SVMModel.load(sc, "myModelPath"); } } {% endhighlight %} @@ -315,15 +310,12 @@ a dependency.
    -The following example shows how to load a sample dataset, build Logistic Regression model, +The following example shows how to load a sample dataset, build SVM model, and make predictions with the resulting model to compute the training error. -Note that the Python API does not yet support model save/load but will in the future. - {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithSGD +from pyspark.mllib.classification import SVMWithSGD, SVMModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model -model = LogisticRegressionWithSGD.train(parsedData) +model = SVMWithSGD.train(parsedData, iterations=100) # Evaluating the model on training data labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -503,7 +499,7 @@ Note that the Python API does not yet support multiclass classification and mode will in the future. {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint from numpy import array @@ -522,6 +518,10 @@ model = LogisticRegressionWithLBFGS.train(parsedData) labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -672,7 +672,7 @@ values. We compute the mean squared error at the end to evaluate Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel from numpy import array # Load and parse the data @@ -690,6 +690,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} @@ -772,6 +776,58 @@ will get better! +
    + +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create our model by initializing the weights to 0 + +{% highlight python %} +numFeatures = 3 +model = StreamingLinearRegressionWithSGD() +model.setInitialWeights([0.0, 0.0, 0.0]) +{% endhighlight %} + +Now we register the streams for training and testing and start the job. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! + +
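For completeness, a minimal Scala sketch of the same streaming linear regression set-up (an editorial aside, not added by this patch); it assumes a `StreamingContext` `ssc` and labeled points in both streams, as in the Python example above.

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}

// Each line is a labeled point formatted as (y,[x1,x2,x3]).
val trainingData = ssc.textFileStream("/training/data/dir").map(LabeledPoint.parse).cache()
val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)

// Start from a zero weight vector over three features.
val numFeatures = 3
val model = new StreamingLinearRegressionWithSGD()
  .setInitialWeights(Vectors.zeros(numFeatures))

model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()
{% endhighlight %}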
    + @@ -785,8 +841,7 @@ gradient descent (`stepSize`, `numIterations`, `miniBatchFraction`). For each o all three possible regularizations (none, L1 or L2). For Logistic Regression, [L-BFGS](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) -version is implemented under [LogisticRegressionWithLBFGS] -(api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS), and this +version is implemented under [LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS), and this version supports both binary and multinomial Logistic Regression while SGD version only supports binary Logistic Regression. However, L-BFGS version doesn't support L1 regularization but SGD one supports L1 regularization. When L1 regularization is not required, L-BFGS version is strongly diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 4de2d9491ac2b..8df68d81f3c78 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,22 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.2 to 1.3 + +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. + +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. +* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. + So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. 
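To make the last point concrete, here is a hedged sketch (not part of the patch) of the adjustment for linear regression with SGD; the parameter values are hypothetical and `parsedData` is assumed to be an RDD of `LabeledPoint`s:

{% highlight python %}
from pyspark.mllib.regression import LinearRegressionWithSGD

# Hypothetical Spark 1.2 call:
#   model = LinearRegressionWithSGD.train(parsedData, step=1.0, regParam=0.01)
# Equivalent call after the 1.3 change (squared loss divided by 2): halve the
# regularization parameter and double the step size.
model = LinearRegressionWithSGD.train(parsedData, step=2.0, regParam=0.005)
{% endhighlight %}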
+ ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 9780ea52c4994..e73bd30f3a90a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -14,14 +14,13 @@ and use it for prediction. MLlib supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -These models are typically used for [document classification] -(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). Feature values must be nonnegative. The model type is selected with an optional parameter -"Multinomial" or "Bernoulli" with "Multinomial" as the default. +"multinomial" or "bernoulli" with "multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -35,7 +34,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a +smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. @@ -54,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial") +val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() @@ -75,6 +74,8 @@ optionally smoothing parameter `lambda` as input, and output a can be used for evaluation and prediction. {% highlight java %} +import scala.Tuple2; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; @@ -82,7 +83,6 @@ import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.classification.NaiveBayes; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.regression.LabeledPoint; -import scala.Tuple2; JavaRDD training = ... // training set JavaRDD test = ... // test set @@ -119,7 +119,7 @@ used for evaluation and prediction. 
Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint @@ -140,6 +140,10 @@ model = NaiveBayes.train(training, 1.0) # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md new file mode 100644 index 0000000000000..42ea2ca81f80d --- /dev/null +++ b/docs/mllib-pmml-model-export.md @@ -0,0 +1,86 @@ +--- +layout: global +title: PMML model export - MLlib +displayTitle: MLlib - PMML model export +--- + +* Table of contents +{:toc} + +## MLlib supported models + +MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). + +The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. + + + + + + + + + + + + + + + + + + + + + + + + + +
| MLlib model | PMML model |
| --- | --- |
| KMeansModel | ClusteringModel |
| LinearRegressionModel | RegressionModel (functionName="regression") |
| RidgeRegressionModel | RegressionModel (functionName="regression") |
| LassoModel | RegressionModel (functionName="regression") |
| SVMModel | RegressionModel (functionName="classification" normalizationMethod="none") |
| Binary LogisticRegressionModel | RegressionModel (functionName="classification" normalizationMethod="logit") |
    + +## Examples +
    + +
+To export a supported `model` (see table above) to PMML, simply call `model.toPMML`.
+
+Here is a complete example of building a KMeansModel and printing it out in PMML format:
+{% highlight scala %}
+import org.apache.spark.mllib.clustering.KMeans
+import org.apache.spark.mllib.linalg.Vectors
+
+// Load and parse the data
+val data = sc.textFile("data/mllib/kmeans_data.txt")
+val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()
+
+// Cluster the data into two classes using KMeans
+val numClusters = 2
+val numIterations = 20
+val clusters = KMeans.train(parsedData, numClusters, numIterations)
+
+// Export to PMML
+println("PMML Model:\n" + clusters.toPMML)
+{% endhighlight %}
+
+As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats:
+
+{% highlight scala %}
+// Export the model to a String in PMML format
+clusters.toPMML
+
+// Export the model to a local file in PMML format
+clusters.toPMML("/tmp/kmeans.xml")
+
+// Export the model to a directory on a distributed file system in PMML format
+clusters.toPMML(sc, "/tmp/kmeans")
+
+// Export the model to the OutputStream in PMML format
+clusters.toPMML(System.out)
+{% endhighlight %}
+
+For unsupported models, either you will not find a `.toPMML` method or an `IllegalArgumentException` will be thrown.
+
    + +
    diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 887eae7f4f07b..de5d6485f9b5f 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -283,7 +283,7 @@ approxSample = data.sampleByKey(False, fractions); Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically significant, whether this result occurred by chance or not. MLlib currently supports Pearson's -chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine +chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. @@ -422,6 +422,41 @@ for i, result in enumerate(featureTestResults): +Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +for equality of probability distributions. By providing the name of a theoretical distribution +(currently solely supported for the normal distribution) and its parameters, or a function to +calculate the cumulative distribution according to a given theoretical distribution, the user can +test the null hypothesis that their sample is drawn from that distribution. In the case that the +user tests against the normal distribution (`distName="norm"`), but does not provide distribution +parameters, the test initializes to the standard normal distribution and logs an appropriate +message. + +
    +
    +[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.stat.Statistics._ + +val data: RDD[Double] = ... // an RDD of sample data + +// run a KS test for the sample versus a standard normal distribution +val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) +println(testResult) // summary of the test including the p-value, test statistic, + // and null hypothesis + // if our p-value indicates significance, we can reject the null hypothesis + +// perform a KS test using a cumulative distribution function of our making +val myCDF: Double => Double = ... +val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) +{% endhighlight %} +
    +
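As an aside (not part of the patch), the Python API also exposes a `kolmogorovSmirnovTest` method on `pyspark.mllib.stat.Statistics` (in this or a later release); a hedged sketch, assuming an active SparkContext `sc` and a small toy sample:

{% highlight python %}
# A hedged sketch: assumes pyspark.mllib.stat.Statistics exposes
# kolmogorovSmirnovTest and an active SparkContext `sc`.
from pyspark.mllib.stat import Statistics

parallelData = sc.parallelize([0.1, 0.15, 0.2, 0.3, 0.25])  # toy sample data

# Test the sample against a standard normal distribution (mean 0.0, stddev 1.0).
testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0.0, 1.0)
print(testResult)  # summary including the p-value, test statistic, and null hypothesis
{% endhighlight %}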
    + + ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. diff --git a/docs/monitoring.md b/docs/monitoring.md index 1e0fc150862fb..bcf885fe4e681 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -178,9 +178,9 @@ Note that the history server only displays completed Spark jobs. One way to sign In addition to viewing the metrics in the UI, they are also available as JSON. This gives developers an easy way to create new visualizations and monitoring tools for Spark. The JSON is available for -both running applications, and in the history server. The endpoints are mounted at `/json/v1`. Eg., -for the history server, they would typically be accessible at `http://:18080/json/v1`, and -for a running application, at `http://localhost:4040/json/v1`. +both running applications, and in the history server. The endpoints are mounted at `/api/v1`. Eg., +for the history server, they would typically be accessible at `http://:18080/api/v1`, and +for a running application, at `http://localhost:4040/api/v1`. @@ -228,6 +228,14 @@ for a running application, at `http://localhost:4040/json/v1`. + + + + + + + +
| Endpoint | Meaning |
| --- | --- |
| /applications/[app-id]/storage/rdd/[rdd-id] | Details for the storage status of a given RDD |
| /applications/[app-id]/logs | Download the event logs for all attempts of the given application as a zip file |
| /applications/[app-id]/[attempt-id]/logs | Download the event logs for the specified attempt of the given application as a zip file |
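As an illustrative aside (not part of the patch), these endpoints return plain JSON and can be consumed from any HTTP client. A minimal Python 3 sketch, assuming a history server on its default port 18080 and a hypothetical application id:

{% highlight python %}
import json
from urllib.request import urlopen

# List the applications known to the history server (default port 18080 assumed).
with urlopen("http://localhost:18080/api/v1/applications") as resp:
    apps = json.loads(resp.read().decode("utf-8"))
print([app["id"] for app in apps])

# Fetch the jobs of one application; the id below is hypothetical.
with urlopen("http://localhost:18080/api/v1/applications/app-20150101000000-0000/jobs") as resp:
    jobs = json.loads(resp.read().decode("utf-8"))
print(len(jobs), "jobs")
{% endhighlight %}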
    When running on Yarn, each application has multiple attempts, so `[app-id]` is actually @@ -240,12 +248,12 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `json/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the -running app, you would go to `http://localhost:4040/json/v1/applications/[app-id]/jobs`. This is to +running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. # Metrics diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 27816515c5de2..ae712d62746f6 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -41,19 +41,20 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency artifactId = hadoop-client version = -Finally, you need to import some Spark classes and implicit conversions into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following lines: {% highlight scala %} import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.SparkConf {% endhighlight %} +(Before Spark 1.3.0, you need to explicitly `import org.apache.spark.SparkContext._` to enable essential implicit conversions.) +
    -Spark {{site.SPARK_VERSION}} works with Java 6 and higher. If you are using Java 8, Spark supports +Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java 8, Spark supports [lambda expressions](http://docs.oracle.com/javase/tutorial/java/javaOO/lambdaexpressions.html) for concisely writing functions, otherwise you can use the classes in the [org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package. @@ -97,9 +98,9 @@ to your version of HDFS. Some common HDFS version tags are listed on the [Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. -Finally, you need to import some Spark classes into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following line: -{% highlight scala %} +{% highlight python %} from pyspark import SparkContext, SparkConf {% endhighlight %} @@ -477,7 +478,6 @@ the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
    - ## RDD Operations @@ -821,11 +821,9 @@ by a key. In Scala, these operations are automatically available on RDDs containing [Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) objects -(the built-in tuples in the language, created by simply writing `(a, b)`), as long as you -import `org.apache.spark.SparkContext._` in your program to enable Spark's implicit -conversions. The key-value pair operations are available in the +(the built-in tuples in the language, created by simply writing `(a, b)`). The key-value pair operations are available in the [PairRDDFunctions](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions) class, -which automatically wraps around an RDD of tuples if you import the conversions. +which automatically wraps around an RDD of tuples. For example, the following code uses the `reduceByKey` operation on key-value pairs to count how many times each line of text occurs in a file: @@ -916,7 +914,8 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1029,7 +1028,9 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1071,7 +1072,7 @@ for details. saveAsSequenceFile(path)
    (Java and Scala) - Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that either implement Hadoop's Writable interface. In Scala, it is also + Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that implement Hadoop's Writable interface. In Scala, it is also available on types that are implicitly convertible to Writable (Spark includes conversions for basic types like Int, Double, String, etc). @@ -1122,7 +1123,7 @@ ordered data following shuffle then it's possible to use: * `sortBy` to make a globally ordered RDD Operations which can cause a shuffle include **repartition** operations like -[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +[`repartition`](#RepartitionLink) and [`coalesce`](#CoalesceLink), **'ByKey** operations (except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and **join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). @@ -1138,14 +1139,16 @@ read the relevant sorted blocks. Certain shuffle operations can consume significant amounts of heap memory since they employ in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are not cleaned up from Spark's temporary storage until Spark is stopped, which means that -long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need -to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may +consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the @@ -1213,9 +1216,11 @@ storage levels is: Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors to be smaller and to share a pool of memory, making it attractive in environments with large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory + the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. + from memory. 
If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon + out-of-the-box. Please refer to this page + for the suggested version pairings. @@ -1566,7 +1571,8 @@ You can see some [example Spark programs](http://spark.apache.org/examples.html) In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run Java and Scala examples by passing the class name to Spark's `bin/run-example` script; for instance: ./bin/run-example SparkPi @@ -1575,6 +1581,10 @@ For Python examples, use `spark-submit` instead: ./bin/spark-submit examples/src/main/python/pi.py +For R examples, use `spark-submit` instead: + + ./bin/spark-submit examples/src/main/r/dataframe.R + For help on optimizing your programs, the [configuration](configuration.html) and [tuning](tuning.html) guides provide information on best practices. They are especially important for making sure that your data is stored in memory in an efficient format. @@ -1582,4 +1592,4 @@ For help on deploying, the [cluster mode overview](cluster-overview.html) descri in distributed operation and supported cluster managers. Finally, full API documentation is available in -[Scala](api/scala/#org.apache.spark.package), [Java](api/java/) and [Python](api/python/). +[Scala](api/scala/#org.apache.spark.package), [Java](api/java/), [Python](api/python/) and [R](api/R/). diff --git a/docs/quick-start.md b/docs/quick-start.md index 81143da865cf0..bb39e4111f244 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -184,10 +184,10 @@ scala> linesWithSpark.cache() res7: spark.RDD[String] = spark.FilteredRDD@17e51082 scala> linesWithSpark.count() -res8: Long = 15 +res8: Long = 19 scala> linesWithSpark.count() -res9: Long = 15 +res9: Long = 19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -202,10 +202,10 @@ a cluster, as described in the [programming guide](programming-guide.html#initia >>> linesWithSpark.cache() >>> linesWithSpark.count() -15 +19 >>> linesWithSpark.count() -15 +19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -423,14 +423,14 @@ dependencies to `spark-submit` through its `--py-files` argument by packaging th We can run this application using the `bin/spark-submit` script: -{% highlight python %} +{% highlight bash %} # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --master local[4] \ SimpleApp.py ... Lines with a: 46, Lines with b: 23 -{% endhighlight python %} +{% endhighlight %} @@ -444,7 +444,8 @@ Congratulations on running your first Spark application! * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). 
+ [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run them as follows: {% highlight bash %} @@ -453,4 +454,7 @@ You can run them as follows: # For Python examples, use spark-submit directly: ./bin/spark-submit examples/src/main/python/pi.py + +# For R examples, use spark-submit directly: +./bin/spark-submit examples/src/main/r/dataframe.R {% endhighlight %} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5f1d6daeb27f0..1f915d8ea1d73 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,14 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. + # Mesos Docker Support Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` @@ -298,6 +306,20 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.constraints + Attribute based constraints to be matched against when accepting resource offers. + + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
      +
- Scalar constraints are matched with "less than or equal" semantics, i.e. the value in the constraint must be less than or equal to the value in the resource offer.
- Range constraints are matched with "contains" semantics, i.e. the value in the constraint must be within the resource offer's value.
- Set constraints are matched with "subset of" semantics, i.e. the value in the constraint must be a subset of the resource offer's value.
- Text constraints are matched with "equality" semantics, i.e. the value in the constraint must be exactly equal to the resource offer's value.
- If no value is present as part of the constraint, any offer with the corresponding attribute will be accepted (without a value check); a short configuration sketch follows below.
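The sketch below (not part of the patch) shows the `spark.mesos.constraints` setting from the earlier example applied through a PySpark `SparkConf`; the application name and Mesos master URL are placeholders:

{% highlight python %}
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("constraints-example")      # placeholder application name
        .setMaster("mesos://host:5050")         # placeholder Mesos master URL
        .set("spark.mesos.constraints", "tachyon=true;us-east-1=false"))
sc = SparkContext(conf=conf)
{% endhighlight %}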
    + + # Troubleshooting and Debugging diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 51c1339165024..de22ab557cacf 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -7,6 +7,51 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Launching Spark on YARN + +Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to write to HDFS and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). + +There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. + +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +To launch a Spark application in `yarn-cluster` mode: + + `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. + +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: + + $ ./bin/spark-shell --master yarn-client + +## Adding Other JARs + +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. + + $ ./bin/spark-submit --class my.main.Class \ + --master yarn-cluster \ + --jars my-other-jar.jar,my-other-other-jar.jar + my-main-jar.jar + app_arg1 app_arg2 + + # Preparations Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. @@ -17,6 +62,38 @@ To build Spark yourself, refer to [Building Spark](building-spark.html). Most of the configs are the same for Spark on YARN as for other deployment modes. 
See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. +# Debugging your Application + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId + +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). + +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. + +To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a +large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +on the nodes on which containers are launched. This directory contains the launch script, JARs, and +all environment variables used for launching each container. This process is useful for debugging +classpath problems in particular. (Note that enabling this requires admin privileges on cluster +settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). + +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files + to be uploaded with the application. +- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` + (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, + the `file:` protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). + +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. 
+ #### Spark Properties @@ -50,8 +127,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -71,9 +148,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes - + + + + + + @@ -176,8 +266,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -193,7 +283,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -229,85 +319,50 @@ Most of the configs are the same for Spark on YARN as for other deployment modes running against earlier versions, this property will be ignored. + + + + + + + + + + + + + + + + + + + +
    spark.yarn.am.waitTime 100s - In yarn-cluster mode, time for the application master to wait for the - SparkContext to be initialized. In yarn-client mode, time for the application master to wait + In `yarn-cluster` mode, time for the application master to wait for the + SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait for the driver to connect to it.
spark.yarn.scheduler.heartbeat.interval-ms | default changed from 5000 to 3000 | The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. The value is capped at half the value of YARN's configuration for the expiry interval (yarn.am.liveness-monitor.expiry-interval-ms).
spark.yarn.scheduler.initial-allocation.interval | 200ms (newly added) | The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager when there are pending container allocation requests. It should be no larger than spark.yarn.scheduler.heartbeat.interval-ms. The allocation interval will be doubled on successive eager heartbeats if pending containers still exist, until spark.yarn.scheduler.heartbeat.interval-ms is reached.
    Add the environment variable specified by EnvironmentVariableName to the Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In yarn-cluster mode this controls - the environment of the SPARK driver and in yarn-client mode it only controls + these and to set multiple environment variables. In `yarn-cluster` mode this controls + the environment of the SPARK driver and in `yarn-client` mode it only controls the environment of the executor launcher.
    (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use `spark.driver.extraJavaOptions` instead.
    spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. + This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. +
    spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure HDFS. +
    spark.yarn.config.gatewayPath(none) + A path that is valid on the gateway host (the host where a Spark application is started) but may + differ for paths for the same resource in other nodes in the cluster. Coupled with + spark.yarn.config.replacementPath, this is used to support clusters with + heterogeneous configurations, so that Spark can correctly launch remote processes. +

    + The replacement path normally will contain a reference to some environment variable exported by + YARN (and, thus, visible to Spark containers). +

    + For example, if the gateway node has Hadoop libraries installed on /disk1/hadoop, and + the location of the Hadoop install is exported by YARN as the HADOOP_HOME + environment variable, setting this value to /disk1/hadoop and the replacement path to + $HADOOP_HOME will make sure that paths used to launch remote processes properly + reference the local YARN configuration. +

    spark.yarn.config.replacementPath(none) + See spark.yarn.config.gatewayPath. +
    -# Launching Spark on YARN - -Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. The -configuration contained in this directory will be distributed to the YARN cluster so that all -containers used by the application use the same configuration. If the configuration references -Java system properties or environment variables not managed by YARN, they should also be set in the -Spark application's configuration (driver, executors, and the AM when running in client mode). - -There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. - -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". - -To launch a Spark application in yarn-cluster mode: - - ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - -For example: - - $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ - --driver-memory 4g \ - --executor-memory 2g \ - --executor-cores 1 \ - --queue thequeue \ - lib/spark-examples*.jar \ - 10 - -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. - -To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: - - $ ./bin/spark-shell --master yarn-client - -## Adding Other JARs - -In yarn-cluster mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. - - $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar - app_arg1 app_arg2 - -# Debugging your Application - -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. - - yarn logs -applicationId - -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. 
The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). - -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. - -To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` -on the nodes on which containers are launched. This directory contains the launch script, JARs, and -all environment variables used for launching each container. This process is useful for debugging -classpath problems in particular. (Note that enabling this requires admin privileges on cluster -settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). - -To use a custom log4j configuration for the application master or executors, there are two options: - -- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files - to be uploaded with the application. -- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` - (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, - the `file:` protocol should be explicitly provided, and the file needs to exist locally on all - the nodes. - -Note that for the first option, both executors and the application master will share the same -log4j configuration, which may cause issues when they run on the same node (e.g. trying to write -to the same log file). - -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. - # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 0eed9adacf123..4f71fbc086cd0 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./sbin/start-slave.sh + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -77,7 +77,7 @@ Note, the master machine accesses each of the worker machines via ssh. By defaul If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. 
-Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: +Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`: - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. diff --git a/docs/sparkr.md b/docs/sparkr.md new file mode 100644 index 0000000000000..4385a4eeacd5c --- /dev/null +++ b/docs/sparkr.md @@ -0,0 +1,232 @@ +--- +layout: global +displayTitle: SparkR (R on Spark) +title: SparkR (R on Spark) +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. +In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that +supports operations like selection, filtering, aggregation etc. (similar to R data frames, +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. + +# SparkR DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +structured data files, tables in Hive, external databases, or existing local R data frames. + +All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. + +## Starting Up: SparkContext, SQLContext + +
+
The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster.
+You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name,
+any Spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`,
+which can be created from the SparkContext. If you are working from the SparkR shell, the
+`SQLContext` and `SparkContext` should already be created for you.
+
+{% highlight r %}
+sc <- sparkR.init()
+sqlContext <- sparkRSQL.init(sc)
+{% endhighlight %}
+
+ +
## Creating DataFrames
+With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).
+
+### From local data frames
+The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically, we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based on the `faithful` dataset from R.
+
    +{% highlight r %} +df <- createDataFrame(sqlContext, faithful) + +# Displays the content of the DataFrame to stdout +head(df) +## eruptions waiting +##1 3.600 79 +##2 1.800 54 +##3 3.333 74 + +{% endhighlight %} +
    + +### From Data Sources + +SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by +specifying `--packages` with `spark-submit` or `sparkR` commands, or if creating context through `init` +you can specify the packages with the `packages` argument. + +
    +{% highlight r %} +sc <- sparkR.init(sparkPackages="com.databricks:spark-csv_2.11:1.0.3") +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
+ +
We can see how to use data sources with an example JSON input file. Note that the file used here is _not_ a typical JSON file: each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail.
+
    + +{% highlight r %} +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +head(people) +## age name +##1 NA Michael +##2 30 Andy +##3 19 Justin + +# SparkR automatically infers the schema from the JSON file +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +{% endhighlight %} +
    + +The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example +to a Parquet file using `write.df` + +
    +{% highlight r %} +write.df(people, path="people.parquet", source="parquet", mode="overwrite") +{% endhighlight %} +
    + +### From Hive tables + +You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). + +
    +{% highlight r %} +# sc is an existing SparkContext. +hiveContext <- sparkRHive.init(sc) + +sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- sql(hiveContext, "FROM src SELECT key, value") + +# results is now a DataFrame +head(results) +## key value +## 1 238 val_238 +## 2 86 val_86 +## 3 311 val_311 + +{% endhighlight %} +
    + +## DataFrame Operations + +SparkR DataFrames support a number of functions to do structured data processing. +Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: + +### Selecting rows, columns + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, faithful) + +# Get basic information about the DataFrame +df +## DataFrame[eruptions:double, waiting:double] + +# Select only the "eruptions" column +head(select(df, df$eruptions)) +## eruptions +##1 3.600 +##2 1.800 +##3 3.333 + +# You can also pass in column name as strings +head(select(df, "eruptions")) + +# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +head(filter(df, df$waiting < 50)) +## eruptions waiting +##1 1.750 47 +##2 1.750 47 +##3 1.867 48 + +{% endhighlight %} + +
    + +### Grouping, Aggregation + +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below + +
    +{% highlight r %} + +# We use the `n` operator to count the number of times each waiting time appears +head(summarize(groupBy(df, df$waiting), count = n(df$waiting))) +## waiting count +##1 81 13 +##2 60 6 +##3 68 1 + +# We can also sort the output from the aggregation to get the most common waiting times +waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting)) +head(arrange(waiting_counts, desc(waiting_counts$count))) + +## waiting count +##1 78 15 +##2 83 14 +##3 81 13 + +{% endhighlight %} +
+ +
### Operating on Columns
+
+SparkR also provides a number of functions that can be directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions.
+
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can assign this to a new column in the same DataFrame +df$waiting_secs <- df$waiting * 60 +head(df) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 + +{% endhighlight %} +
+ +
## Running SQL Queries from SparkR
+A SparkR DataFrame can also be registered as a temporary table in Spark SQL; registering a DataFrame as a table allows you to run SQL queries over its data.
+The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`.
+
    +{% highlight r %} +# Load a JSON file +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 78b8e8ad515a0..5838bc172fe86 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,17 +11,18 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. # DataFrames A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell` or the `pyspark` shell. +All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. -## Starting Point: `SQLContext` +## Starting Point: SQLContext
    @@ -64,6 +65,17 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) {% endhighlight %} +
    + +
+ +The entry point into all relational functionality in Spark is the +`SQLContext` class, or one of its descendants. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
    @@ -97,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Displays the content of the DataFrame to stdout df.show() @@ -110,7 +122,7 @@ df.show() JavaSparkContext sc = ...; // An existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Displays the content of the DataFrame to stdout df.show(); @@ -123,13 +135,26 @@ df.show(); from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Displays the content of the DataFrame to stdout df.show() {% endhighlight %} + +
+{% highlight r %} +sqlContext <- sparkRSQL.init(sc) + +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Displays the content of the DataFrame to stdout +showDF(df) +{% endhighlight %} + +
    + @@ -146,7 +171,7 @@ val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Show the content of the DataFrame df.show() @@ -196,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Show the content of the DataFrame df.show(); @@ -252,7 +277,7 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) # Create the DataFrame -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Show the content of the DataFrame df.show() @@ -296,6 +321,57 @@ df.groupBy("age").count().show() {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) + +# Create the DataFrame +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Show the content of the DataFrame +showDF(df) +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +printSchema(df) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +showDF(select(df, "name")) +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +showDF(select(df, df$name, df$age + 1)) +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +showDF(where(df, df$age > 21)) +## age name +## 30 Andy + +# Count people by age +showDF(count(groupBy(df, "age"))) +## age count +## null 1 +## 19 1 +## 30 1 + +{% endhighlight %} + +
    + @@ -325,6 +401,14 @@ sqlContext = SQLContext(sc) df = sqlContext.sql("SELECT * FROM table") {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +df <- sql(sqlContext, "SELECT * FROM table") +{% endhighlight %} +
    + @@ -693,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +val df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %}
    @@ -703,8 +787,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); -df.select("name", "favorite_color").save("namesAndFavColors.parquet"); +DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); {% endhighlight %} @@ -714,11 +798,20 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %} + + +
    + +{% highlight r %} +df <- loadDF(sqlContext, "people.parquet") +saveDF(select(df, "name", "age"), "namesAndAges.parquet") +{% endhighlight %} +
    @@ -726,16 +819,16 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") You can also manually specify the data source that will be used along with any extra options that you would like to pass to the data source. Data sources are specified by their fully qualified -name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use the shorted -name (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") +df.select("name", "age").write.format("parquet").save("namesAndAges.parquet") {% endhighlight %}
    @@ -744,8 +837,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); -df.select("name", "age").save("namesAndAges.parquet", "parquet"); +DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); +df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); {% endhighlight %} @@ -755,8 +848,18 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") +df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") + +{% endhighlight %} + +
    +
    + +{% highlight r %} + +df <- loadDF(sqlContext, "people.json", "json") +saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") {% endhighlight %} @@ -804,7 +907,7 @@ new data. Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. @@ -844,11 +947,11 @@ import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. -people.saveAsParquetFile("people.parquet") +people.write.parquet("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. -val parquetFile = sqlContext.parquetFile("people.parquet") +val parquetFile = sqlContext.read.parquet("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile") @@ -866,13 +969,13 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet"); +schemaPeople.write().parquet("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. -DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); +DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); -//Parquet files can also be registered as tables and then used in SQL statements. +// Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.javaRDD().map(new Function() { @@ -892,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function() schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet") +schemaPeople.write.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.parquetFile("people.parquet") +parquetFile = sqlContext.read.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -908,6 +1011,40 @@ for teenName in teenNames.collect():
    +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +schemaPeople # The DataFrame from the previous example. + +# DataFrames can be saved as Parquet files, maintaining the schema information. +saveAsParquetFile(schemaPeople, "people.parquet") + +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# The result of loading a parquet file is also a DataFrame. +parquetFile <- parquetFile(sqlContext, "people.parquet") + +# Parquet files can also be registered as tables and then used in SQL statements. +registerTempTable(parquetFile, "parquetFile"); +teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) +for (teenName in collect(teenNames)) { + cat(teenName, "\n") +} +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.sql("REFRESH TABLE my_table") +{% endhighlight %} + +
    +
    {% highlight sql %} @@ -926,12 +1063,12 @@ SELECT * FROM parquetTable
-### Partition discovery +### Partition Discovery Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For exmaple, we can store all our previously used +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -959,9 +1096,9 @@ path {% endhighlight %} -By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will -automatically extract the partitioning information from the paths. Now the schema of the returned -DataFrame becomes: +By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +will automatically extract the partitioning information from the paths. +Now the schema of the returned DataFrame becomes: {% highlight text %} @@ -974,9 +1111,13 @@ root {% endhighlight %} Notice that the data types of the partitioning columns are automatically inferred. Currently, -numeric data types and string type are supported. +numeric data types and string type are supported. Sometimes users may not want to automatically +infer the data types of the partitioning columns. For these use cases, the automatic type inference +can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which defaults to +`true`. When type inference is disabled, string type will be used for the partitioning columns. -### Schema merging + +### Schema Merging Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with a simple schema, and gradually add more columns to the schema as needed. In this way, users may end @@ -993,20 +1134,20 @@ source is now able to automatically detect this case and merge schemas of all th import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory -val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") -df1.saveAsParquetFile("data/test_table/key=1") +val df1 = sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column -val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") -df2.saveAsParquetFile("data/test_table/key=2") +val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.parquetFile("data/test_table") +val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together -// with the partiioning column appeared in the partition directory paths. +// with the partitioning column appeared in the partition directory paths.
// root // |-- single: int (nullable = true) // |-- double: int (nullable = true) @@ -1033,11 +1174,38 @@ df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) df2.save("data/test_table/key=2", "parquet") # Read the partitioned table -df3 = sqlContext.parquetFile("data/test_table") +df3 = sqlContext.load("data/test_table", "parquet") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + + + +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") + +# Read the partitioned table +df3 <- loadDF(sqlContext, "data/test_table", "parquet") +printSchema(df3) + +# The final schema consists of all 3 columns in the Parquet files together +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1049,6 +1217,79 @@ df3.printSchema()
+### Hive metastore Parquet table conversion + +When reading from and writing to Hive metastore Parquet tables, Spark SQL will try to use its own +Parquet support instead of Hive SerDe for better performance. This behavior is controlled by the +`spark.sql.hive.convertMetastoreParquet` configuration, and is turned on by default. + +#### Hive/Parquet Schema Reconciliation + +There are two key differences between Hive and Parquet from the perspective of table schema +processing. + +1. Hive is case insensitive, while Parquet is not +1. Hive considers all columns nullable, while nullability in Parquet is significant + +For this reason, we must reconcile the Hive metastore schema with the Parquet schema when converting a +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: + +1. Fields that have the same name in both schemas must have the same data type regardless of + nullability. The reconciled field should have the data type of the Parquet side, so that + nullability is respected. + +1. The reconciled schema contains exactly those fields defined in the Hive metastore schema. + + - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. + - Any fields that only appear in the Hive metastore schema are added as nullable fields in the + reconciled schema. + +#### Metadata Refreshing + +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables is also cached. If these tables are +updated by Hive or other external tools, you need to refresh them manually to ensure consistent +metadata. +
    + +
    + +{% highlight scala %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
+ +{% highlight java %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table"); +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight sql %} +REFRESH TABLE my_table; +{% endhighlight %} + +
    + +
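+The conversion described in this section can also be toggled at runtime. The following is a minimal, illustrative sketch that is not part of the original examples; it assumes an existing `HiveContext` named `sqlContext`:
+
+{% highlight scala %}
+// sqlContext is an existing HiveContext (assumed for this sketch)
+// Fall back to the Hive SerDe instead of Spark SQL's native Parquet support
+sqlContext.setConf("spark.sql.hive.convertMetastoreParquet", "false")
+
+// Re-enable the default behavior (use Spark SQL's own Parquet support)
+sqlContext.setConf("spark.sql.hive.convertMetastoreParquet", "true")
+{% endhighlight %}
+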
    + ### Configuration Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running @@ -1061,7 +1302,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do - not differentiate between binary data and strings when writing out the Parquet schema. This + not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. @@ -1078,7 +1319,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.cacheMetadata true - Turns on caching of Parquet schema metadata. Can speed up querying of static data. + Turns on caching of Parquet schema metadata. Can speed up querying of static data. @@ -1094,7 +1335,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Paruet 1.6.0rc3 (PARQUET-136). + bug in Parquet 1.6.0rc3 (PARQUET-136). However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn this feature on. @@ -1107,6 +1348,34 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support. + + spark.sql.parquet.output.committer.class + org.apache.parquet.hadoop.
    ParquetOutputCommitter
    + +

    + The output committer class used by Parquet. The specified class needs to be a subclass of + org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a + subclass of org.apache.parquet.hadoop.ParquetOutputCommitter. +

    +

    + Note: +

      +
• This option must be set via Hadoop Configuration rather than Spark + SQLConf. +
• This option overrides spark.sql.sources.outputCommitterClass. +
    +

    +

    + Spark SQL comes with a builtin + org.apache.spark.sql.
    parquet.DirectParquetOutputCommitter
, which can be more + efficient than the default Parquet output committer when writing data to S3. +

    + + ## JSON Datasets @@ -1114,12 +1383,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: +This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +or a JSON file. -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. - -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1130,8 +1397,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a DataFrame from the file(s) pointed to by path -val people = sqlContext.jsonFile(path) +val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1149,19 +1415,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +val anotherPeople = sqlContext.read.json(anotherPeopleRDD) {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext` : - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1171,9 +1435,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. -String path = "examples/src/main/resources/people.json"; -// Create a DataFrame from the file(s) pointed to by path -DataFrame people = sqlContext.jsonFile(path); +DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -1192,18 +1454,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json` on a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1214,9 +1473,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. -path = "examples/src/main/resources/people.json" -# Create a DataFrame from the file(s) pointed to by path -people = sqlContext.jsonFile(path) +people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1238,6 +1495,39 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %}
    +
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame using +the `jsonFile` function, which loads data from a directory of JSON files where each line of the +files is a JSON object. + +Note that the file that is offered as _a json file_ is not a typical JSON file. Each +line must contain a separate, self-contained valid JSON object. As a consequence, +a regular multi-line JSON file will most often fail. + +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRSQL.init(sc) + +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path <- "examples/src/main/resources/people.json" +# Create a DataFrame from the file(s) pointed to by path +people <- jsonFile(sqlContext, path) + +# The inferred schema can be visualized using the printSchema() method. +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method provided by `sqlContext`. +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +{% endhighlight %} +
    +
{% highlight sql %} @@ -1265,7 +1555,12 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note that when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +and `hive-site.xml` under the `conf/` directory need to be available on the driver and all executors launched by the +YARN cluster. A convenient way to do this is to add them through the `--jars` and `--files` options of the +`spark-submit` command. +
@@ -1294,12 +1589,12 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc.sc()); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); @@ -1314,10 +1609,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be -expressed in HiveQL. - +adds support for finding tables in the MetaStore and writing queries using HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext @@ -1331,9 +1623,91 @@ results = sqlContext.sql("FROM src SELECT key, value").collect() {% endhighlight %} +
    + +
    + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRHive.init(sc) + +sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- collect(sql(sqlContext, "FROM src SELECT key, value")) + +{% endhighlight %} +
    +### Interacting with Different Versions of Hive Metastore + +One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. + +Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts. + + + + + + + + + + + + + + + + + + + + + + + +
Property Name | Default | Meaning
spark.sql.hive.metastore.version 0.13.1 + Version of the Hive metastore. Available + options are 0.12.0 and 0.13.1. Support for more versions is coming in the future. +
spark.sql.hive.metastore.jars builtin + Location of the jars that should be used to instantiate the HiveMetastoreClient. This + property can be one of three options: +
      +
1. builtin: Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + enabled. When this option is chosen, spark.sql.hive.metastore.version must be + either 0.13.1 or not defined. +
2. maven: Use Hive jars of specified version downloaded from Maven repositories. +
3. A classpath in the standard format for both Hive and Hadoop. +
    +
spark.sql.hive.metastore.sharedPrefixes com.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc
    +

    + A comma separated list of class prefixes that should be loaded using the classloader that is + shared between Spark SQL and a specific version of Hive. An example of classes that should + be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + to be shared are those that interact with classes that are already shared. For example, + custom appenders that are used by log4j. +

    +
spark.sql.hive.metastore.barrierPrefixes (empty) +

    + A comma separated list of class prefixes that should explicitly be reloaded for each version + of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + prefix that typically would be shared (i.e. org.apache.spark.*). +

    +
    + + ## JDBC To Other Databases Spark SQL also includes a data source that can read data from other databases using JDBC. This @@ -1367,7 +1741,7 @@ the Data Sources API. The following options are supported: dbtable - The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses. @@ -1399,9 +1773,9 @@ the Data Sources API. The following options are supported:
    {% highlight scala %} -val jdbcDF = sqlContext.load("jdbc", Map( - "url" -> "jdbc:postgresql:dbserver", - "dbtable" -> "schema.tablename")) +val jdbcDF = sqlContext.read.format("jdbc").options( + Map("url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")).load() {% endhighlight %}
    @@ -1414,7 +1788,7 @@ Map options = new HashMap(); options.put("url", "jdbc:postgresql:dbserver"); options.put("dbtable", "schema.tablename"); -DataFrame jdbcDF = sqlContext.load("jdbc", options) +DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% endhighlight %} @@ -1424,7 +1798,17 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql:dbserver', dbtable='schema.tablename').load() + +{% endhighlight %} + +
    + +
    + +{% highlight r %} + +df <- loadDF(sqlContext, source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") {% endhighlight %} @@ -1501,7 +1885,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command - `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. + ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. @@ -1520,11 +1904,20 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. + + spark.sql.planner.externalSort + false + + When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. + + # Distributed SQL Engine -Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. +In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, +without the need to write any code. ## Running the Thrift JDBC/ODBC server @@ -1538,7 +1931,7 @@ To start the JDBC/ODBC server, run the following in the Spark directory: This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. By default, the server listens on localhost:10000. You may override this -bahaviour via either environment variables, i.e.: +behaviour via either environment variables, i.e.: {% highlight bash %} export HIVE_SERVER2_THRIFT_PORT= @@ -1603,6 +1996,25 @@ options. ## Upgrading from Spark SQL 1.3 to 1.4 +#### DataFrame data reader/writer interface + +Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) +and writing data out (`DataFrame.write`), +and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). + +See the API docs for `SQLContext.read` ( + Scala, + Java, + Python +) and `DataFrame.write` ( + Scala, + Java, + Python +) more information. + + +#### DataFrame.groupBy retains grouping columns + Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
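+The following Scala sketch is not taken from the guide itself; `df`, `sqlContext`, and the column names are illustrative assumptions. It shows the new `DataFrame.groupBy().agg()` default described above and how to restore the 1.3 behavior:
+
+{% highlight scala %}
+import org.apache.spark.sql.functions._
+
+// In Spark 1.4, the grouping column "department" is retained in the result by default
+df.groupBy("department").agg(max("age"), sum("expense"))
+
+// To revert to the 1.3 behavior (grouping columns dropped from the result):
+sqlContext.setConf("spark.sql.retainGroupColumns", "false")
+df.groupBy("department").agg(max("age"), sum("expense"))
+{% endhighlight %}
+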
    @@ -1726,7 +2138,7 @@ sqlContext.udf.register("strLen", (s: String) => s.length())
    {% highlight java %} -sqlContext.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType); {% endhighlight %}
    @@ -2354,5 +2766,151 @@ from pyspark.sql.types import *
    +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data type | Value type in R | API to access or create a data type
    ByteType + integer
    + Note: Numbers will be converted to 1-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -128 to 127. +
    + "byte" +
    ShortType + integer
    + Note: Numbers will be converted to 2-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -32768 to 32767. +
    + "short" +
    IntegerType integer + "integer" +
    LongType + integer
    + Note: Numbers will be converted to 8-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of + -9223372036854775808 to 9223372036854775807. + Otherwise, please convert data to decimal.Decimal and use DecimalType. +
    + "long" +
    FloatType + numeric
    + Note: Numbers will be converted to 4-byte single-precision floating + point numbers at runtime. +
    + "float" +
    DoubleType numeric + "double" +
    DecimalType Not supported + Not supported +
    StringType character + "string" +
    BinaryType raw + "binary" +
    BooleanType logical + "bool" +
    TimestampType POSIXct + "timestamp" +
    DateType Date + "date" +
    ArrayType vector or list + list(type="array", elementType=elementType, containsNull=[containsNull])
    + Note: The default value of containsNull is True. +
    MapType environment + list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
    + Note: The default value of valueContainsNull is True. +
    StructType named list + list(type="struct", fields=fields)
    + Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
    StructField The value type in R of the data type of this field + (For example, integer for a StructField with the data type IntegerType) + list(name=name, type=dataType, nullable=nullable) +
    + +
    +
    diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 6a2048121f8bf..a75587a92adc7 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). +the ones for which it has built-in support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. Note that custom receivers can be implemented @@ -21,15 +21,15 @@ A custom receiver must extend this abstract class by implementing two methods - `onStop()`: Things to do to stop receiving data. Both `onStart()` and `onStop()` must not block indefinitely. Typically, `onStart()` would start the threads -that responsible for receiving the data and `onStop()` would ensure that the receiving by those threads +that are responsible for receiving the data, and `onStop()` would ensure that these threads receiving the data are stopped. The receiving threads can also use `isStopped()`, a `Receiver` method, to check whether they should stop receiving data. Once the data is received, that data can be stored inside Spark by calling `store(data)`, which is a method provided by the Receiver class. -There are number of flavours of `store()` which allow you store the received data -record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavour of -`store()` used to implemented a receiver affects its reliability and fault-tolerance semantics. +There are a number of flavors of `store()` which allow one to store the received data +record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavor of +`store()` used to implement a receiver affects its reliability and fault-tolerance semantics. This is discussed [later](#receiver-reliability) in more detail. Any exception in the receiving threads should be caught and handled properly to avoid silent @@ -60,7 +60,7 @@ class CustomReceiver(host: String, port: Int) def onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -123,7 +123,7 @@ public class JavaCustomReceiver extends Receiver { public void onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -167,7 +167,7 @@ public class JavaCustomReceiver extends Receiver { The custom receiver can be used in a Spark Streaming application by using `streamingContext.receiverStream()`. This will create -input DStream using data received by the instance of custom receiver, as shown below +an input DStream using data received by the instance of custom receiver, as shown below:
    @@ -206,22 +206,20 @@ there are two kinds of receivers based on their reliability and fault-tolerance and stored in Spark reliably (that is, replicated successfully). Usually, implementing this receiver involves careful consideration of the semantics of source acknowledgements. -1. *Unreliable Receiver* - These are receivers for unreliable sources that do not support - acknowledging. Even for reliable sources, one may implement an unreliable receiver that - do not go into the complexity of acknowledging correctly. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgement to a source. This can be used for sources that do not support acknowledgement, or even for reliable sources when one does not want or need to go into the complexity of acknowledgement. To implement a *reliable receiver*, you have to use `store(multiple-records)` to store data. -This flavour of `store` is a blocking call which returns only after all the given records have +This flavor of `store` is a blocking call which returns only after all the given records have been stored inside Spark. If the receiver's configured storage level uses replication (enabled by default), then this call returns after replication has completed. Thus it ensures that the data is reliably stored, and the receiver can now acknowledge the -source appropriately. This ensures that no data is caused when the receiver fails in the middle +source appropriately. This ensures that no data is lost when the receiver fails in the middle of replicating data -- the buffered data will not be acknowledged and hence will be later resent by the source. An *unreliable receiver* does not have to implement any of this logic. It can simply receive records from the source and insert them one-at-a-time using `store(single-record)`. While it does -not get the reliability guarantees of `store(multiple-records)`, it has the following advantages. +not get the reliability guarantees of `store(multiple-records)`, it has the following advantages: - The system takes care of chunking that data into appropriate sized blocks (look for block interval in the [Spark Streaming Programming Guide](streaming-programming-guide.html)). diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index c8ab146bcae0a..de0461010daec 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -58,6 +58,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java).
    +
    + from pyspark.streaming.flume import FlumeUtils + + flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). +
    Note that the hostname should be the same as the one used by the resource manager in the @@ -99,6 +108,12 @@ Configuring Flume on the chosen machine requires the following two steps. artifactId = scala-library version = {{site.SCALA_VERSION}} + (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)). + + groupId = org.apache.commons + artifactId = commons-lang3 + version = 3.3.2 + 2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. agent.sinks = spark @@ -129,6 +144,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
    + from pyspark.streaming.flume import FlumeUtils + + addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] + flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). +
    See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 64714f0b799fc..775d508d4879b 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,12 +2,12 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. 
See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; @@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
    @@ -74,15 +74,15 @@ Next, we discuss how to use this approach in your streaming application. [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This is a new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature in Spark 1.3 and is only available in the Scala and Java API. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. -This approach has the following advantages over the received-based approach (i.e. Approach 1). +This approach has the following advantages over the receiver-based approach (i.e. Approach 1). -- *Simplified Parallelism:* No need to create multiple input Kafka streams and union-ing them. With `directStream`, Spark Streaming will create as many RDD partitions as there is Kafka partitions to consume, which will all read data from Kafka in parallel. So there is one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminate the problem as there is no receiver, and hence no need for Write Ahead Logs. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. 
This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper and offsets tracked only by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). @@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application. streamingContext, [map of Kafka parameters], [set of topics to consume]) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; @@ -116,8 +116,15 @@ Next, we discuss how to use this approach in your streaming application. [map of Kafka parameters], [set of topics to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). +
    +
    + from pyspark.streaming.kafka import KafkaUtils + directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py).
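For concreteness, a minimal Scala sketch of the direct stream creation described above; the broker list and topic names are illustrative placeholders, and `streamingContext` is assumed to be an existing StreamingContext:

{% highlight scala %}
import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka.KafkaUtils

// Illustrative values -- replace with your own brokers and topics
val brokers = "broker1:9092,broker2:9092"
val topics = Set("topicA", "topicB")

val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
val directKafkaStream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  streamingContext, kafkaParams, topics)
{% endhighlight %}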
    @@ -128,29 +135,60 @@ Next, we discuss how to use this approach in your streaming application.
    - directKafkaStream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges] - // offsetRanges.length = # of Kafka partitions being consumed - ... + // Hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + directKafkaStream.transform { rdd => + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.map { + ... + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + ... }
    - directKafkaStream.foreachRDD( - new Function, Void>() { - @Override - public Void call(JavaPairRDD rdd) throws IOException { - OffsetRange[] offsetRanges = ((HasOffsetRanges)rdd).offsetRanges - // offsetRanges.length = # of Kafka partitions being consumed - ... - return null; - } + // Hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); + + directKafkaStream.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + return rdd; + } + } + ).map( + ... + ).foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + ... + return null; } + } );
    +
    + Not supported yet
    +
    You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. - Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. + Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. -3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file +3. **Deploying:** This is same as the first approach, for Scala, Java and Python. diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 379eb513d521e..aa9749afbc867 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -32,7 +32,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream val kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]) + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. 
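To make the updated Scala signature above concrete, here is a hedged sketch; the application name, stream name, endpoint, and region are placeholders, and `streamingContext` is assumed to already exist:

{% highlight scala %}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.kinesis.KinesisUtils

// Placeholder names, endpoint and region -- substitute your own values
val kinesisStream = KinesisUtils.createStream(
  streamingContext, "myKinesisApp", "myKinesisStream",
  "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
  InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
{% endhighlight %}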
@@ -44,7 +45,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]); + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2); See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. @@ -54,19 +56,23 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `streamingContext`: StreamingContext containing an application name used by Kinesis to tie this Kinesis application to the Kinesis stream - - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from - - The application name used in the streaming context becomes the Kinesis application name + - `[Kinesis app name]`: The application name that will be used to checkpoint the Kinesis + sequence numbers in a DynamoDB table. - The application name must be unique for a given account and region. - - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization. - - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table. + - If the table exists but has incorrect checkpoint information (for a different stream, or + old expired sequence numbers), then there may be temporary errors. + - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from. - `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region). + - `[region name]`: Valid Kinesis region names can be found [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html). + - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application. - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + In other versions of the API, you can also specify the AWS access key and secret key directly. 3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). @@ -122,12 +128,12 @@ To run the example,
    - bin/run-example streaming.KinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
    - bin/run-example streaming.JavaKinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
    @@ -136,7 +142,7 @@ To run the example, - To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer. - bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10 + bin/run-example streaming.KinesisWordProducerASL [Kinesis stream name] [endpoint URL] 1000 10 This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bd863d48d53e3..2f3013b533eb0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -11,7 +11,7 @@ description: Spark Streaming programming guide and tutorial for Spark SPARK_VERS # Overview Spark Streaming is an extension of the core Spark API that enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ, Kinesis or TCP sockets can be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis, or TCP sockets, and can be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's @@ -52,7 +52,7 @@ different languages. **Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream transformations and almost all the output operations available in Scala and Java interfaces. -However, it has only support for basic sources like text files and text data over sockets. +However, it only has support for basic sources like text files and text data over sockets. APIs for additional sources, like Kafka and Flume, will be available in the future. Further information about available features in the Python API are mentioned throughout this document; look out for the tag @@ -69,15 +69,15 @@ do is as follows.
    -First, we import the names of the Spark Streaming classes, and some implicit -conversions from StreamingContext into our environment, to add useful methods to +First, we import the names of the Spark Streaming classes and some implicit +conversions from StreamingContext into our environment in order to add useful methods to other classes we need (like DStream). [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the -main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. +main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and a batch interval of 1 second. {% highlight scala %} import org.apache.spark._ import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. // The master requires 2 cores to prevent from a starvation scenario. @@ -96,7 +96,7 @@ val lines = ssc.socketTextStream("localhost", 9999) This `lines` DStream represents the stream of data that will be received from the data server. Each record in this DStream is a line of text. Next, we want to split the lines by -space into words. +space characters into words. {% highlight scala %} // Split each line into words @@ -109,7 +109,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Next, we want to count these words. {% highlight scala %} -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Count each word in each batch val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) @@ -463,7 +463,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). Note that this internally creates a [SparkContext](api/scala/index.html#org.apache.spark.SparkContext) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `StreamingContext` object can also be created from an existing `SparkContext` object. @@ -498,7 +498,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `JavaStreamingContext` object can also be created from an existing `JavaSparkContext`. @@ -531,7 +531,7 @@ receive it there. 
However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details.
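As a quick illustration of reusing an existing context with the batch interval discussed above, a minimal Scala sketch; the master URL and application name are illustrative placeholders:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Placeholder configuration: 2 local threads and an illustrative app name
val conf = new SparkConf().setMaster("local[2]").setAppName("ExistingContextExample")
val sc = new SparkContext(conf)

// Reuse the existing SparkContext and set a batch interval of 1 second
val ssc = new StreamingContext(sc, Seconds(1))
{% endhighlight %}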
    @@ -549,7 +549,7 @@ After a context is defined, you have to do the following. - Once a context has been started, no new streaming computations can be set up or added to it. - Once a context has been stopped, it cannot be restarted. - Only one StreamingContext can be active in a JVM at the same time. -- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set optional parameter of `stop()` called `stopSparkContext` to false. +- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set the optional parameter of `stop()` called `stopSparkContext` to false. - A SparkContext can be re-used to create multiple StreamingContexts, as long as the previous StreamingContext is stopped (without stopping the SparkContext) before the next StreamingContext is created. *** @@ -583,7 +583,7 @@ the `flatMap` operation is applied on each RDD in the `lines` DStream to generat These underlying RDD transformations are computed by the Spark engine. The DStream operations -hide most of these details and provide the developer with higher-level API for convenience. +hide most of these details and provide the developer with a higher-level API for convenience. These operations are discussed in detail in later sections. *** @@ -600,7 +600,7 @@ data from a source and stores it in Spark's memory for processing. Spark Streaming provides two categories of built-in streaming sources. - *Basic sources*: Sources directly available in the StreamingContext API. - Example: file systems, socket connections, and Akka actors. + Examples: file systems, socket connections, and Akka actors. - *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section. @@ -610,11 +610,11 @@ We are going to discuss some of the sources present in each category later in th Note that, if you want to receive multiple streams of data in parallel in your streaming application, you can create multiple input DStreams (discussed further in the [Performance Tuning](#level-of-parallelism-in-data-receiving) section). This will -create multiple receivers which will simultaneously receive multiple data streams. But note that -Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the -Spark Streaming application. Hence, it is important to remember that Spark Streaming application +create multiple receivers which will simultaneously receive multiple data streams. But note that a +Spark worker/executor is a long-running task, hence it occupies one of the cores allocated to the +Spark Streaming application. Therefore, it is important to remember that a Spark Streaming application needs to be allocated enough cores (or threads, if running locally) to process the received data, -as well as, to run the receiver(s). +as well as to run the receiver(s). ##### Points to remember {:.no_toc} @@ -623,13 +623,13 @@ as well as, to run the receiver(s). Either of these means that only one thread will be used for running tasks locally. If you are using a input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will be used to run the receiver, leaving no thread for processing the received data. 
Hence, when - running locally, always use "local[*n*]" as the master URL where *n* > number of receivers to run - (see [Spark Properties](configuration.html#spark-properties.html) for information on how to set + running locally, always use "local[*n*]" as the master URL, where *n* > number of receivers to run + (see [Spark Properties](configuration.html#spark-properties) for information on how to set the master). - Extending the logic to running on a cluster, the number of cores allocated to the Spark Streaming - application must be more than the number of receivers. Otherwise the system will receive data, but - not be able to process them. + application must be more than the number of receivers. Otherwise the system will receive data, but + not be able to process it. ### Basic Sources {:.no_toc} @@ -639,7 +639,7 @@ which creates a DStream from text data received over a TCP socket connection. Besides sockets, the StreamingContext API provides methods for creating DStreams from files and Akka actors as input sources. -- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as +- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as:
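For the plain-text case, the file-stream bullet above can be sketched as follows; the directory path is a placeholder and `streamingContext` is assumed to exist:

{% highlight scala %}
// Watches the given directory and creates a DStream from any new text files written into it
val lines = streamingContext.textFileStream("hdfs://namenode:8040/streaming/input/")
{% endhighlight %}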
    @@ -682,14 +682,14 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea ### Advanced Sources {:.no_toc} -Python API As of Spark 1.3, -out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. +Python API As of Spark {{site.SPARK_VERSION_SHORT}}, +out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts -of dependencies, the functionality to create DStreams from these sources have been moved to separate -libraries, that can be [linked](#linking) to explicitly when necessary. For example, if you want to -create a DStream using data from Twitter's stream of tweets, you have to do the following. +of dependencies, the functionality to create DStreams from these sources has been moved to separate +libraries that can be [linked](#linking) to explicitly when necessary. For example, if you want to +create a DStream using data from Twitter's stream of tweets, you have to do the following: 1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the SBT/Maven project dependencies. @@ -719,11 +719,11 @@ TwitterUtils.createStream(jssc); Note that these advanced sources are not available in the Spark shell, hence applications based on these advanced sources cannot be tested in the shell. If you really want to use them in the Spark shell you will have to download the corresponding Maven artifact's JAR along with its dependencies -and it in the classpath. +and add it to the classpath. Some of these advanced sources are as follows. -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.1.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. @@ -743,7 +743,7 @@ Some of these advanced sources are as follows. Python API This is not yet supported in Python. -Input DStreams can also be created out of custom data sources. All you have to do is implement an +Input DStreams can also be created out of custom data sources. All you have to do is implement a user-defined **receiver** (see next section to understand what that is) that can receive data from the custom sources and push it into Spark. See the [Custom Receiver Guide](streaming-custom-receivers.html) for details. @@ -753,14 +753,12 @@ Guide](streaming-custom-receivers.html) for details. There can be two kinds of data sources based on their *reliability*. Sources (like Kafka and Flume) allow the transferred data to be acknowledged. If the system receiving -data from these *reliable* sources acknowledge the received data correctly, it can be ensured -that no data gets lost due to any kind of failure. This leads to two kinds of receivers. +data from these *reliable* sources acknowledges the received data correctly, it can be ensured +that no data will be lost due to any kind of failure. This leads to two kinds of receivers: -1. 
*Reliable Receiver* - A *reliable receiver* correctly acknowledges a reliable - source that the data has been received and stored in Spark with replication. -1. *Unreliable Receiver* - These are receivers for sources that do not support acknowledging. Even - for reliable sources, one may implement an unreliable receiver that do not go into the complexity - of acknowledging correctly. +1. *Reliable Receiver* - A *reliable receiver* correctly sends acknowledgment to a reliable + source when the data has been received and stored in Spark with replication. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgment to a source. This can be used for sources that do not support acknowledgment, or even for reliable sources when one does not want or need to go into the complexity of acknowledgment. The details of how to write a reliable receiver are discussed in the [Custom Receiver Guide](streaming-custom-receivers.html). @@ -828,7 +826,7 @@ Some of the common ones are as follows. cogroup(otherStream, [numTasks]) - When called on DStream of (K, V) and (K, W) pairs, return a new DStream of + When called on a DStream of (K, V) and (K, W) pairs, return a new DStream of (K, Seq[V], Seq[W]) tuples. @@ -852,13 +850,15 @@ A few of these transformations are worth discussing in more detail. The `updateStateByKey` operation allows you to maintain arbitrary state while continuously updating it with new information. To use this, you will have to do two steps. -1. Define the state - The state can be of arbitrary data type. +1. Define the state - The state can be an arbitrary data type. 1. Define the state update function - Specify with a function how to update the state using the -previous state and the new values from input stream. +previous state and the new values from an input stream. + +In every batch, Spark will apply the state update function for all existing keys, regardless of whether they have new data in a batch or not. If the update function returns `None` then the key-value pair will be eliminated. Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. We -define the update function as +define the update function as:
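The update function elided here might look roughly like the following Scala sketch, where `pairs` is the DStream of `(word, 1)` pairs from the earlier example:

{% highlight scala %}
def updateFunction(newValues: Seq[Int], runningCount: Option[Int]): Option[Int] = {
  // Add the counts from the current batch to the previous running count (0 if none yet)
  Some(runningCount.getOrElse(0) + newValues.sum)
}

val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
{% endhighlight %}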
    @@ -947,7 +947,7 @@ operation that is not exposed in the DStream API. For example, the functionality of joining every batch in a data stream with another dataset is not directly exposed in the DStream API. However, you can easily use `transform` to do this. This enables very powerful possibilities. For example, -if you want to do real-time data cleaning by joining the input data stream with precomputed +one can do real-time data cleaning by joining the input data stream with precomputed spam information (maybe generated with Spark as well) and then filtering based on it.
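A hedged Scala sketch of that spam-filtering idea; `wordCounts` and `ssc` are assumed from the earlier examples, and the way `spamInfoRDD` is loaded and the filtering threshold are purely illustrative:

{% highlight scala %}
// RDD of (word, spamScore) pairs, assumed to have been precomputed and saved earlier
val spamInfoRDD = ssc.sparkContext.objectFile[(String, Double)]("hdfs://namenode:8040/spam-scores")

val cleanedDStream = wordCounts.transform { rdd =>
  // Join every batch with the spam information and drop records that look like spam
  rdd.join(spamInfoRDD).filter { case (word, (count, spamScore)) => spamScore < 0.5 }
}
{% endhighlight %}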
    @@ -991,13 +991,14 @@ cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(.
    -In fact, you can also use [machine learning](mllib-guide.html) and -[graph computation](graphx-programming-guide.html) algorithms in the `transform` method. +Note that the supplied function gets called in every batch interval. This allows you to do +time-varying RDD operations, that is, RDD operations, number of partitions, broadcast variables, +etc. can be changed between batches. #### Window Operations {:.no_toc} Spark Streaming also provides *windowed computations*, which allow you to apply -transformations over a sliding window of data. This following figure illustrates this sliding +transformations over a sliding window of data. The following figure illustrates this sliding window.

    @@ -1009,11 +1010,11 @@ window. As shown in the figure, every time the window *slides* over a source DStream, the source RDDs that fall within the window are combined and operated upon to produce the -RDDs of the windowed DStream. In this specific case, the operation is applied over last 3 time +RDDs of the windowed DStream. In this specific case, the operation is applied over the last 3 time units of data, and slides by 2 time units. This shows that any window operation needs to specify two parameters. - * window length - The duration of the window (3 in the figure) + * window length - The duration of the window (3 in the figure). * sliding interval - The interval at which the window operation is performed (2 in the figure). @@ -1021,7 +1022,7 @@ These two parameters must be multiples of the batch interval of the source DStre figure). Let's illustrate the window operations with an example. Say, you want to extend the -[earlier example](#a-quick-example) by generating word counts over last 30 seconds of data, +[earlier example](#a-quick-example) by generating word counts over the last 30 seconds of data, every 10 seconds. To do this, we have to apply the `reduceByKey` operation on the `pairs` DStream of `(word, 1)` pairs over the last 30 seconds of data. This is done using the operation `reduceByKeyAndWindow`. @@ -1096,13 +1097,13 @@ said two parameters - windowLength and slideInterval. reduceByKeyAndWindow(func, invFunc, windowLength, slideInterval, [numTasks]) - A more efficient version of the above reduceByKeyAndWindow() where the reduce + A more efficient version of the above reduceByKeyAndWindow() where the reduce value of each window is calculated incrementally using the reduce values of the previous window. - This is done by reducing the new data that enter the sliding window, and "inverse reducing" the - old data that leave the window. An example would be that of "adding" and "subtracting" counts - of keys as the window slides. However, it is applicable to only "invertible reduce functions", + This is done by reducing the new data that enters the sliding window, and "inverse reducing" the + old data that leaves the window. An example would be that of "adding" and "subtracting" counts + of keys as the window slides. However, it is applicable only to "invertible reduce functions", that is, those reduce functions which have a corresponding "inverse reduce" function (taken as - parameter invFunc. Like in reduceByKeyAndWindow, the number of reduce tasks + parameter invFunc). Like in reduceByKeyAndWindow, the number of reduce tasks is configurable through an optional argument. Note that [checkpointing](#checkpointing) must be enabled for using this operation. @@ -1224,7 +1225,7 @@ For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.stre *** ## Output Operations on DStreams -Output operations allow DStream's data to be pushed out external systems like a database or a file systems. +Output operations allow DStream's data to be pushed out to external systems like a database or a file systems. Since the output operations actually allow the transformed data to be consumed by external systems, they trigger the actual execution of all the DStream transformations (similar to actions for RDDs). 
Currently, the following output operations are defined: @@ -1233,7 +1234,7 @@ Currently, the following output operations are defined: Output OperationMeaning print() - Prints first ten elements of every batch of data in a DStream on the driver node running + Prints the first ten elements of every batch of data in a DStream on the driver node running the streaming application. This is useful for development and debugging.
    Python API This is called @@ -1242,12 +1243,12 @@ Currently, the following output operations are defined: saveAsTextFiles(prefix, [suffix]) - Save this DStream's contents as a text files. The file name at each batch interval is + Save this DStream's contents as text files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". saveAsObjectFiles(prefix, [suffix]) - Save this DStream's contents as a SequenceFile of serialized Java objects. The file + Save this DStream's contents as SequenceFiles of serialized Java objects. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    @@ -1257,7 +1258,7 @@ Currently, the following output operations are defined: saveAsHadoopFiles(prefix, [suffix]) - Save this DStream's contents as a Hadoop file. The file name at each batch interval is + Save this DStream's contents as Hadoop files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    Python API This is not available in @@ -1267,7 +1268,7 @@ Currently, the following output operations are defined: foreachRDD(func) The most generic output operator that applies a function, func, to each RDD generated from - the stream. This function should push the data in each RDD to a external system, like saving the RDD to + the stream. This function should push the data in each RDD to an external system, such as saving the RDD to files, or writing it over the network to a database. Note that the function func is executed in the driver process running the streaming application, and will usually have RDD actions in it that will force the computation of the streaming RDDs. @@ -1277,14 +1278,14 @@ Currently, the following output operations are defined: ### Design Patterns for using foreachRDD {:.no_toc} -`dstream.foreachRDD` is a powerful primitive that allows data to sent out to external systems. +`dstream.foreachRDD` is a powerful primitive that allows data to be sent out to external systems. However, it is important to understand how to use this primitive correctly and efficiently. Some of the common mistakes to avoid are as follows. Often writing data to external system requires creating a connection object (e.g. TCP connection to a remote server) and using it to send data to a remote system. For this purpose, a developer may inadvertently try creating a connection object at -the Spark driver, but try to use it in a Spark worker to save records in the RDDs. +the Spark driver, and then try to use it in a Spark worker to save records in the RDDs. For example (in Scala),
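The Scala snippet referred to just above is roughly the following sketch of the *incorrect* pattern; `createNewConnection()` is a hypothetical helper standing in for whatever client your external system uses:

{% highlight scala %}
dstream.foreachRDD { rdd =>
  val connection = createNewConnection()  // runs at the driver; the object must then be serialized to workers
  rdd.foreach { record =>
    connection.send(record)  // runs at the worker, typically failing with serialization errors
  }
}
{% endhighlight %}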

    @@ -1346,7 +1347,7 @@ dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use -`rdd.foreachPartition` - create a single connection object and send all the records in a RDD +`rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
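A sketch of the `rdd.foreachPartition` pattern recommended above, again using the hypothetical `createNewConnection()` helper:

{% highlight scala %}
dstream.foreachRDD { rdd =>
  rdd.foreachPartition { partitionOfRecords =>
    val connection = createNewConnection()  // created on the worker, one per partition
    partitionOfRecords.foreach(record => connection.send(record))
    connection.close()
  }
}
{% endhighlight %}

An even better variant keeps a static pool of connections on each worker so that connections can be reused across batches rather than created per partition.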
    @@ -1427,26 +1428,6 @@ You can easily use [DataFrames and SQL](sql-programming-guide.html) operations o
    {% highlight scala %} -/** Lazily instantiated singleton instance of SQLContext */ -object SQLContextSingleton { - @transient private var instance: SQLContext = null - - // Instantiate SQLContext on demand - def getInstance(sparkContext: SparkContext): SQLContext = synchronized { - if (instance == null) { - instance = new SQLContext(sparkContext) - } - instance - } -} - -... - -/** Case class for converting RDD to DataFrame */ -case class Row(word: String) - -... - /** DataFrame operations inside your streaming program */ val words: DStream[String] = ... @@ -1454,11 +1435,11 @@ val words: DStream[String] = ... words.foreachRDD { rdd => // Get the singleton instance of SQLContext - val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + val sqlContext = SQLContext.getOrCreate(rdd.sparkContext) import sqlContext.implicits._ - // Convert RDD[String] to RDD[case class] to DataFrame - val wordsDataFrame = rdd.map(w => Row(w)).toDF() + // Convert RDD[String] to DataFrame + val wordsDataFrame = rdd.toDF("word") // Register as table wordsDataFrame.registerTempTable("words") @@ -1476,19 +1457,6 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
    {% highlight java %} -/** Lazily instantiated singleton instance of SQLContext */ -class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { - if (instance == null) { - instance = new SQLContext(sparkContext); - } - return instance; - } -} - -... - /** Java Bean class for converting RDD to DataFrame */ public class JavaRow implements java.io.Serializable { private String word; @@ -1512,7 +1480,9 @@ words.foreachRDD( new Function2, Time, Void>() { @Override public Void call(JavaRDD rdd, Time time) { - SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Get the singleton instance of SQLContext + SQLContext sqlContext = SQLContext.getOrCreate(rdd.context()); // Convert RDD[String] to RDD[case class] to DataFrame JavaRDD rowRDD = rdd.map(new Function() { @@ -1581,7 +1551,7 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
-You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember sufficient amount of streaming data such that query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). +You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember a sufficient amount of streaming data such that the query can run. Otherwise the StreamingContext, which is unaware of the asynchronous SQL queries, will delete old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more about DataFrames. @@ -1594,7 +1564,7 @@ You can also easily use machine learning algorithms provided by [MLlib](mllib-gu ## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, -using `persist()` method on a DStream will automatically persist every RDD of that DStream in +using the `persist()` method on a DStream will automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. @@ -1606,28 +1576,27 @@ default persistence level is set to replicate the data to two nodes for fault-to Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More -information on different persistence levels can be found in -[Spark Programming Guide](programming-guide.html#rdd-persistence). +information on different persistence levels can be found in the [Spark Programming Guide](programming-guide.html#rdd-persistence). *** ## Checkpointing A streaming application must operate 24/7 and hence must be resilient to failures unrelated to the application logic (e.g., system failures, JVM crashes, etc.). For this to be possible, -Spark Streaming needs to *checkpoints* enough information to a fault- +Spark Streaming needs to *checkpoint* enough information to a fault- tolerant storage system such that it can recover from failures. There are two types of data that are checkpointed. - *Metadata checkpointing* - Saving of the information defining the streaming computation to fault-tolerant storage like HDFS. This is used to recover from failure of the node running the driver of the streaming application (discussed in detail later). Metadata includes: - + *Configuration* - The configuration that were used to create the streaming application. + + *Configuration* - The configuration that was used to create the streaming application.
+ *DStream operations* - The set of DStream operations that define the streaming application. + *Incomplete batches* - Batches whose jobs are queued but have not completed yet. - *Data checkpointing* - Saving of the generated RDDs to reliable storage. This is necessary in some *stateful* transformations that combine data across multiple batches. In such - transformations, the generated RDDs depends on RDDs of previous batches, which causes the length - of the dependency chain to keep increasing with time. To avoid such unbounded increase in recovery + transformations, the generated RDDs depend on RDDs of previous batches, which causes the length + of the dependency chain to keep increasing with time. To avoid such unbounded increases in recovery time (proportional to dependency chain), intermediate RDDs of stateful transformations are periodically *checkpointed* to reliable storage (e.g. HDFS) to cut off the dependency chains. @@ -1641,10 +1610,10 @@ transformations are used. Checkpointing must be enabled for applications with any of the following requirements: - *Usage of stateful transformations* - If either `updateStateByKey` or `reduceByKeyAndWindow` (with - inverse function) is used in the application, then the checkpoint directory must be provided for - allowing periodic RDD checkpointing. + inverse function) is used in the application, then the checkpoint directory must be provided to + allow for periodic RDD checkpointing. - *Recovering from failures of the driver running the application* - Metadata checkpoints are used - for to recover with progress information. + to recover with progress information. Note that simple streaming applications without the aforementioned stateful transformations can be run without enabling checkpointing. The recovery from driver failures will also be partial in @@ -1659,7 +1628,7 @@ Checkpointing can be enabled by setting a directory in a fault-tolerant, reliable file system (e.g., HDFS, S3, etc.) to which the checkpoint information will be saved. This is done by using `streamingContext.checkpoint(checkpointDirectory)`. This will allow you to use the aforementioned stateful transformations. Additionally, -if you want make the application recover from driver failures, you should rewrite your +if you want to make the application recover from driver failures, you should rewrite your streaming application to have the following behavior. + When the program is being started for the first time, it will create a new StreamingContext, @@ -1780,18 +1749,17 @@ You can also explicitly create a `StreamingContext` from the checkpoint data and In addition to using `getOrCreate` one also needs to ensure that the driver process gets restarted automatically on failure. This can only be done by the deployment infrastructure that is used to run the application. This is further discussed in the -[Deployment](#deploying-applications.html) section. +[Deployment](#deploying-applications) section. Note that checkpointing of RDDs incurs the cost of saving to reliable storage. This may cause an increase in the processing time of those batches where RDDs get checkpointed. Hence, the interval of checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every batch may significantly reduce operation throughput. Conversely, checkpointing too infrequently -causes the lineage and task sizes to grow which may have detrimental effects. For stateful +causes the lineage and task sizes to grow, which may have detrimental effects. 
For stateful transformations that require RDD checkpointing, the default interval is a multiple of the batch interval that is at least 10 seconds. It can be set by using -`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 times of -sliding interval of a DStream is good setting to try. +`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 sliding intervals of a DStream is a good setting to try. *** @@ -1864,17 +1832,17 @@ To run a Spark Streaming applications, you need to have the following. {:.no_toc} If a running Spark Streaming application needs to be upgraded with new -application code, then there are two possible mechanism. +application code, then there are two possible mechanisms. - The upgraded Spark Streaming application is started and run in parallel to the existing application. -Once the new one (receiving the same data as the old one) has been warmed up and ready +Once the new one (receiving the same data as the old one) has been warmed up and is ready for prime time, the old one be can be brought down. Note that this can be done for data sources that support sending the data to two destinations (i.e., the earlier and upgraded applications). - The existing application is shutdown gracefully (see [`StreamingContext.stop(...)`](api/scala/index.html#org.apache.spark.streaming.StreamingContext) or [`JavaStreamingContext.stop(...)`](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) -for graceful shutdown options) which ensure data that have been received is completely +for graceful shutdown options) which ensure data that has been received is completely processed before shutdown. Then the upgraded application can be started, which will start processing from the same point where the earlier application left off. Note that this can be done only with input sources that support source-side buffering @@ -1909,10 +1877,10 @@ The following two metrics in web UI are particularly important: to finish. If the batch processing time is consistently more than the batch interval and/or the queueing -delay keeps increasing, then it indicates the system is -not able to process the batches as fast they are being generated and falling behind. +delay keeps increasing, then it indicates that the system is +not able to process the batches as fast they are being generated and is falling behind. In that case, consider -[reducing](#reducing-the-processing-time-of-each-batch) the batch processing time. +[reducing](#reducing-the-batch-processing-times) the batch processing time. The progress of a Spark Streaming program can also be monitored using the [StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface, @@ -1923,8 +1891,8 @@ and it is likely to be improved upon (i.e., more information reported) in the fu *************************************************************************************************** # Performance Tuning -Getting the best performance of a Spark Streaming application on a cluster requires a bit of -tuning. This section explains a number of the parameters and configurations that can tuned to +Getting the best performance out of a Spark Streaming application on a cluster requires a bit of +tuning. This section explains a number of the parameters and configurations that can be tuned to improve the performance of you application. At a high level, you need to consider two things: 1. 
Reducing the processing time of each batch of data by efficiently using cluster resources. @@ -1934,22 +1902,22 @@ improve the performance of you application. At a high level, you need to conside ## Reducing the Batch Processing Times There are a number of optimizations that can be done in Spark to minimize the processing time of -each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section +each batch. These have been discussed in detail in the [Tuning Guide](tuning.html). This section highlights some of the most important ones. ### Level of Parallelism in Data Receiving {:.no_toc} -Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to deserialized +Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to be deserialized and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider parallelizing the data receiving. Note that each input DStream creates a single receiver (running on a worker machine) that receives a single stream of data. Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two -Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. These multiple -DStream can be unioned together to create a single DStream. Then the transformations that was -being applied on the single input DStream can applied on the unified stream. This is done as follows. +Kafka input streams, each receiving only one topic. This would run two receivers, +allowing data to be received in parallel, thus increasing overall throughput. These multiple +DStreams can be unioned together to create a single DStream. Then the transformations that were +being applied on a single input DStream can be applied on the unified stream. This is done as follows.
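In Scala, that multi-receiver pattern might be sketched as follows; the ZooKeeper quorum, consumer group, and topic map are placeholders, and `ssc` is assumed to exist:

{% highlight scala %}
import org.apache.spark.streaming.kafka.KafkaUtils

val numStreams = 5
val kafkaStreams = (1 to numStreams).map { i =>
  // Placeholder arguments: ZooKeeper quorum, consumer group, and a map of topic -> thread count
  KafkaUtils.createStream(ssc, "zk1:2181", "my-consumer-group", Map("myTopic" -> 1))
}
val unifiedStream = ssc.union(kafkaStreams)
unifiedStream.print()
{% endhighlight %}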
    @@ -1971,16 +1939,24 @@ JavaPairDStream unifiedStream = streamingContext.union(kafkaStre unifiedStream.print(); {% endhighlight %}
    +
    +{% highlight python %} +numStreams = 5 +kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] +unifiedStream = streamingContext.union(kafkaStreams) +unifiedStream.print() +{% endhighlight %} +
    Another parameter that should be considered is the receiver's blocking interval, which is determined by the [configuration parameter](configuration.html#spark-streaming) `spark.streaming.blockInterval`. For most receivers, the received data is coalesced together into blocks of data before storing inside Spark's memory. The number of blocks in each batch -determines the number of tasks that will be used to process those +determines the number of tasks that will be used to process the received data in a map-like transformation. The number of tasks per receiver per batch will be approximately (batch interval / block interval). For example, block interval of 200 ms will -create 10 tasks per 2 second batches. Too low the number of tasks (that is, less than the number +create 10 tasks per 2 second batches. If the number of tasks is too low (that is, less than the number of cores per machine), then it will be inefficient as all available cores will not be used to process the data. To increase the number of tasks for a given batch interval, reduce the block interval. However, the recommended minimum value of block interval is about 50 ms, @@ -1988,7 +1964,7 @@ below which the task launching overheads may be a problem. An alternative to receiving data with multiple input streams / receivers is to explicitly repartition the input data stream (using `inputStream.repartition()`). -This distributes the received batches of data across specified number of machines in the cluster +This distributes the received batches of data across the specified number of machines in the cluster before further processing. ### Level of Parallelism in Data Processing @@ -1996,7 +1972,7 @@ before further processing. Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is controlled by -the`spark.default.parallelism` [configuration property](configuration.html#spark-properties). You +the `spark.default.parallelism` [configuration property](configuration.html#spark-properties). You can pass the level of parallelism as an argument (see [`PairDStreamFunctions`](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions) documentation), or set the `spark.default.parallelism` @@ -2004,20 +1980,20 @@ documentation), or set the `spark.default.parallelism` ### Data Serialization {:.no_toc} -The overheads of data serialization can be reduce by tuning the serialization formats. In case of streaming, there are two types of data that are being serialized. +The overheads of data serialization can be reduced by tuning the serialization formats. In the case of streaming, there are two types of data that are being serialized. -* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is unsufficient to hold all the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. 
+* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is insufficient to hold all of the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. -* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operation persist data in memory as they would be processed multiple times. However, unlike Spark, by default RDDs are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) to minimize GC overheads. +* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operations persist data in memory as they would be processed multiple times. However, unlike the Spark Core default of [StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$), persisted RDDs generated by streaming computations are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) by default to minimize GC overheads. -In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization)) for more details. Consider registering custom classes, and disabling object reference tracking for Kryo (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). +In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization) for more details. For Kryo, consider registering custom classes, and disabling object reference tracking (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). -In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. +In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of a few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. 
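To make the Kryo advice above concrete, a configuration sketch; the registered case class is a stand-in for your own record types, and the settings shown are one reasonable combination rather than the only option:

{% highlight scala %}
import org.apache.spark.SparkConf

case class LogRecord(host: String, status: Int)  // placeholder application class

val conf = new SparkConf()
  .setAppName("KryoTunedStreamingApp")
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryo.referenceTracking", "false")  // disable object reference tracking
  .registerKryoClasses(Array(classOf[LogRecord]))
{% endhighlight %}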
### Task Launching Overheads {:.no_toc} If the number of tasks launched per second is high (say, 50 or more per second), then the overhead -of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second +of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: * **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task @@ -2036,7 +2012,7 @@ thus allowing sub-second batch size to be viable. For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed as fast as they are being generated. Whether this is true for an application can be found by -[monitoring](#monitoring) the processing times in the streaming web UI, where the batch +[monitoring](#monitoring-applications) the processing times in the streaming web UI, where the batch processing time should be less than the batch interval. Depending on the nature of the streaming @@ -2049,35 +2025,35 @@ production can be sustained. A good approach to figure out the right batch size for your application is to test it with a conservative batch interval (say, 5-10 seconds) and a low data rate. To verify whether the system -is able to keep up with data rate, you can check the value of the end-to-end delay experienced +is able to keep up with the data rate, you can check the value of the end-to-end delay experienced by each processed batch (either look for "Total delay" in Spark driver log4j logs, or use the [StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface). If the delay is maintained to be comparable to the batch size, then the system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it is therefore unstable. Once you have an idea of a stable configuration, you can try increasing the -data rate and/or reducing the batch size. Note that momentary increase in the delay due to -temporary data rate increases maybe fine as long as the delay reduces back to a low value +data rate and/or reducing the batch size. Note that a momentary increase in the delay due to +temporary data rate increases may be fine as long as the delay reduces back to a low value (i.e., less than batch size). *** ## Memory Tuning -Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail +Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail in the [Tuning Guide](tuning.html#memory-tuning). It is strongly recommended that you read that. In this section, we discuss a few tuning parameters specifically in the context of Spark Streaming applications. -The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes of worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then necessary memory will be low. +The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used.
For example, if you want to use a window operation on the last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then the necessary memory will be low. -In general, since the data received through receivers are stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. +In general, since the data received through receivers is stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. It's best to try and see the memory usage on a small scale and estimate accordingly. -Another aspect of memory tuning is garbage collection. For a streaming application that require low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. +Another aspect of memory tuning is garbage collection. For a streaming application that requires low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. -There are a few parameters that can help you tune the memory usage and GC overheads. +There are a few parameters that can help you tune the memory usage and GC overheads: -* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both, the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. +* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. -* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using window operation of 10 minutes, then Spark Streaming will keep around last 10 minutes of data, and actively throw away older data. -Data can be retained for longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. +* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used.
For example, if you are using a window operation of 10 minutes, then Spark Streaming will keep around the last 10 minutes of data, and actively throw away older data. +Data can be retained for a longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. * **CMS Garbage Collector**: Use of the concurrent mark-and-sweep GC is strongly recommended for keeping GC-related pauses consistently low. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more @@ -2107,18 +2083,18 @@ re-computed from the original fault-tolerant dataset using the lineage of operat 1. Assuming that all of the RDD transformations are deterministic, the data in the final transformed RDD will always be the same irrespective of failures in the Spark cluster. -Spark operates on data on fault-tolerant file systems like HDFS or S3. Hence, +Spark operates on data in fault-tolerant file systems like HDFS or S3. Hence, all of the RDDs generated from the fault-tolerant data are also fault-tolerant. However, this is not the case for Spark Streaming as the data in most cases is received over the network (except when `fileStream` is used). To achieve the same fault-tolerance properties for all of the generated RDDs, the received data is replicated among multiple Spark executors in worker nodes in the cluster (default replication factor is 2). This leads to two kinds of data in the -system that needs to recovered in the event of failures: +system that need to be recovered in the event of failures: 1. *Data received and replicated* - This data survives failure of a single worker node as a copy - of it exists on one of the nodes. + of it exists on one of the other nodes. 1. *Data received but buffered for replication* - Since this is not replicated, - the only way to recover that data is to get it again from the source. + the only way to recover this data is to get it again from the source. Furthermore, there are two kinds of failures that we should be concerned about: @@ -2145,13 +2121,13 @@ In any stream processing system, broadly speaking, there are three steps in proc 1. *Receiving the data*: The data is received from sources using Receivers or otherwise. -1. *Transforming the data*: The data received data is transformed using DStream and RDD transformations. +1. *Transforming the data*: The received data is transformed using DStream and RDD transformations. 1. *Pushing out the data*: The final transformed data is pushed out to external systems like file systems, databases, dashboards, etc. -If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. +If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide an exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. -1. *Receiving the data*: Different input sources provided different guarantees. This is discussed in detail in the next subsection. +1. *Receiving the data*: Different input sources provide different guarantees. This is discussed in detail in the next subsection.
1. *Transforming the data*: All data that has been received will be processed _exactly once_, thanks to the guarantees that RDDs provide. Even if there are failures, as long as the received input data is accessible, the final transformed RDDs will always have the same contents. @@ -2163,9 +2139,9 @@ Different input sources provide different guarantees, ranging from _at-least onc ### With Files {:.no_toc} -If all of the input data is already present in a fault-tolerant files system like -HDFS, Spark Streaming can always recover from any failure and process all the data. This gives -*exactly-once* semantics, that all the data will be processed exactly once no matter what fails. +If all of the input data is already present in a fault-tolerant file system like +HDFS, Spark Streaming can always recover from any failure and process all of the data. This gives +*exactly-once* semantics, meaning all of the data will be processed exactly once no matter what fails. ### With Receiver-based Sources {:.no_toc} @@ -2174,21 +2150,21 @@ scenario and the type of receiver. As we discussed [earlier](#receiver-reliability), there are two types of receivers: 1. *Reliable Receiver* - These receivers acknowledge reliable sources only after ensuring that - the received data has been replicated. If such a receiver fails, - the buffered (unreplicated) data does not get acknowledged to the source. If the receiver is - restarted, the source will resend the data, and therefore no data will be lost due to the failure. -1. *Unreliable Receiver* - Such receivers can lose data when they fail due to worker - or driver failures. + the received data has been replicated. If such a receiver fails, the source will not receive + acknowledgment for the buffered (unreplicated) data. Therefore, if the receiver is + restarted, the source will resend the data, and no data will be lost due to the failure. +1. *Unreliable Receiver* - Such receivers do *not* send acknowledgment and therefore *can* lose + data when they fail due to worker or driver failures. Depending on what type of receivers are used we achieve the following semantics. If a worker node fails, then there is no data loss with reliable receivers. With unreliable receivers, data received but not replicated can get lost. If the driver node fails, -then besides these losses, all the past data that was received and replicated in memory will be +then besides these losses, all of the past data that was received and replicated in memory will be lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which saves the received data to fault-tolerant storage. With the [write ahead logs -enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides at-least once guarantee. +ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2234,7 +2210,7 @@ The following table summarizes the semantics under failures: ### With Kafka Direct API {:.no_toc} -In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. 
Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark 1.3) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark {{site.SPARK_VERSION_SHORT}}) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). ## Semantics of output operations {:.no_toc} @@ -2248,9 +2224,16 @@ additional effort may be necessary to achieve exactly-once semantics. There are - *Transactional updates*: All updates are made transactionally so that updates are made exactly once atomically. One way to do this would be the following. - - Use the batch time (available in `foreachRDD`) and the partition index of the transformed RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. - - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else if this was already committed, skip the update. + - Use the batch time (available in `foreachRDD`) and the partition index of the RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. + - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else, if this was already committed, skip the update. + dstream.foreachRDD { (rdd, time) => + rdd.foreachPartition { partitionIterator => + val partitionId = TaskContext.get.partitionId() + val uniqueId = generateUniqueId(time.milliseconds, partitionId) + // use this uniqueId to transactionally commit the data in partitionIterator + } + } *************************************************************************************************** *************************************************************************************************** @@ -2325,7 +2308,7 @@ package and renamed for better clarity. - Java docs * [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html), [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and - [PairJavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/PairJavaDStream.html) + [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index ab4a96f232c13..7c83d68e7993e 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,8 +19,9 @@ # limitations under the License. 
# -from __future__ import with_statement, print_function +from __future__ import division, print_function, with_statement +import codecs import hashlib import itertools import logging @@ -47,8 +48,10 @@ else: from urllib.request import urlopen, Request from urllib.error import HTTPError + raw_input = input + xrange = range -SPARK_EC2_VERSION = "1.2.1" +SPARK_EC2_VERSION = "1.4.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -65,6 +68,9 @@ "1.1.1", "1.2.0", "1.2.1", + "1.3.0", + "1.3.1", + "1.4.0", ]) SPARK_TACHYON_MAP = { @@ -75,6 +81,9 @@ "1.1.1": "0.5.0", "1.2.0": "0.5.0", "1.2.1": "0.5.0", + "1.3.0": "0.5.0", + "1.3.1": "0.5.0", + "1.4.0": "0.6.4", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION @@ -82,7 +91,7 @@ # Default location to get the spark-ec2 scripts (and ami-list) from DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.3" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" def setup_external_libs(libs): @@ -116,7 +125,7 @@ def setup_external_libs(libs): ) with open(tgz_file_path, "wb") as tgz_file: tgz_file.write(download_stream.read()) - with open(tgz_file_path) as tar: + with open(tgz_file_path, "rb") as tar: if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) sys.exit(1) @@ -212,7 +221,8 @@ def parse_args(): "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: %default)") + help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + + "(Hadoop 2.4.0) (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -264,7 +274,8 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. 
Not used if YARN " + + "is used as Hadoop major version (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + @@ -278,6 +289,10 @@ def parse_args(): parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--additional-tags", type="string", default="", + help="Additional tags to set on the machines; tags are comma-separated, while name and " + + "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") parser.add_option( "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") @@ -291,6 +306,13 @@ def parse_args(): "--private-ips", action="store_true", default=False, help="Use private IPs for instances rather than public if VPC/subnet " + "requires that.") + parser.add_option( + "--instance-initiated-shutdown-behavior", default="stop", + choices=["stop", "terminate"], + help="Whether instances should terminate when shut down or just stop") + parser.add_option( + "--instance-profile-name", default=None, + help="IAM profile name to launch instances under") (opts, args) = parser.parse_args() if len(args) != 2: @@ -303,14 +325,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) @@ -347,7 +371,7 @@ def get_validate_spark_version(version, repo): # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-05-08 +# Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. 
EC2_INSTANCE_TYPES = { "c1.medium": "pvm", @@ -389,6 +413,11 @@ def get_validate_spark_version(version, repo): "m3.large": "hvm", "m3.xlarge": "hvm", "m3.2xlarge": "hvm", + "m4.large": "hvm", + "m4.xlarge": "hvm", + "m4.2xlarge": "hvm", + "m4.4xlarge": "hvm", + "m4.10xlarge": "hvm", "r3.large": "hvm", "r3.xlarge": "hvm", "r3.2xlarge": "hvm", @@ -398,6 +427,7 @@ def get_validate_spark_version(version, repo): "t2.micro": "hvm", "t2.small": "hvm", "t2.medium": "hvm", + "t2.large": "hvm", } @@ -419,13 +449,14 @@ def get_spark_ami(opts): b=opts.spark_ec2_git_branch) ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) + reader = codecs.getreader("ascii") try: - ami = urlopen(ami_path).read().strip() - print("Spark AMI: " + ami) + ami = reader(urlopen(ami_path)).read().strip() except: print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) + print("Spark AMI: " + ami) return ami @@ -476,6 +507,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # Rstudio (GUI for R) needs port 8787 for web access + master_group.authorize('tcp', 8787, 8787, authorized_address) # HDFS NFS gateway requires 111,2049,4242 for tcp & udp master_group.authorize('tcp', 111, 111, authorized_address) master_group.authorize('udp', 111, 111, authorized_address) @@ -483,6 +516,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('udp', 2049, 2049, authorized_address) master_group.authorize('tcp', 4242, 4242, authorized_address) master_group.authorize('udp', 4242, 4242, authorized_address) + # RM in YARN mode uses 8088 + master_group.authorize('tcp', 8088, 8088, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -578,7 +613,8 @@ def launch_cluster(conn, opts, cluster_name): block_device_map=block_map, subnet_id=opts.subnet_id, placement_group=opts.placement_group, - user_data=user_data_content) + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -623,16 +659,19 @@ def launch_cluster(conn, opts, cluster_name): for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + slave_res = image.run( + key_name=opts.key_pair, + security_group_ids=[slave_group.id] + additional_group_ids, + instance_type=opts.instance_type, + placement=zone, + min_count=num_slaves_this_zone, + max_count=num_slaves_this_zone, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) slave_nodes += slave_res.instances print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( s=num_slaves_this_zone, @@ -654,32 +693,43 @@ def launch_cluster(conn, opts, cluster_name): master_type = opts.instance_type if opts.zone == 
'all': opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + master_res = image.run( + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + placement=opts.zone, + min_count=1, + max_count=1, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) master_nodes = master_res.instances print("Launched master in %s, regid = %s" % (zone, master_res.id)) # This wait time corresponds to SPARK-4983 print("Waiting for AWS to propagate instance metadata...") - time.sleep(5) - # Give the instances descriptive names + time.sleep(15) + + # Give the instances descriptive names and set additional tags + additional_tags = {} + if opts.additional_tags.strip(): + additional_tags = dict( + map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') + ) + for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + master.add_tags( + dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + ) + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + slave.add_tags( + dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + ) # Return all the instances return (master_nodes, slave_nodes) @@ -743,14 +793,18 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon'] + 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) + modules = list(filter(lambda x: x != "mapreduce", modules)) if opts.ganglia: modules.append('ganglia') + # Clear SPARK_WORKER_INSTANCES if running on YARN + if opts.hadoop_major_version == "yarn": + opts.worker_instances = "" + # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( @@ -860,7 +914,11 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): for i in cluster_instances: i.update() - statuses = conn.get_all_instance_status(instance_ids=[i.id for i in cluster_instances]) + max_batch = 100 + statuses = [] + for j in xrange(0, len(cluster_instances), max_batch): + batch = [i.id for i in cluster_instances[j:j + max_batch]] + statuses.extend(conn.get_all_instance_status(instance_ids=batch)) if cluster_state == 'ssh-ready': if all(i.state == 'running' for i in cluster_instances) and \ @@ -889,7 +947,7 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): # Get number of local disks available for a given EC2 instance type. 
def get_num_disks(instance_type): # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-05-08 + # Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, @@ -931,6 +989,11 @@ def get_num_disks(instance_type): "m3.large": 1, "m3.xlarge": 2, "m3.2xlarge": 2, + "m4.large": 0, + "m4.xlarge": 0, + "m4.2xlarge": 0, + "m4.4xlarge": 0, + "m4.10xlarge": 0, "r3.large": 1, "r3.xlarge": 1, "r3.2xlarge": 1, @@ -940,6 +1003,7 @@ def get_num_disks(instance_type): "t2.micro": 0, "t2.small": 0, "t2.medium": 0, + "t2.large": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -984,6 +1048,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] + worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" template_vars = { "master_list": '\n'.join(master_addresses), "active_master": active_master, @@ -997,7 +1062,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "spark_version": spark_v, "tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": "%d" % opts.worker_instances, + "spark_worker_instances": worker_instances_str, "spark_master_opts": opts.master_opts } @@ -1090,8 +1155,8 @@ def ssh(host, opts, command): # If this was an ssh failure, provide the user with hints. if e.returncode == 255: raise UsageError( - "Failed to SSH to remote host {0}.\n" + - "Please check that you have provided the correct --identity-file and " + + "Failed to SSH to remote host {0}.\n" + "Please check that you have provided the correct --identity-file and " "--key-pair parameters and try again.".format(host)) else: raise e @@ -1152,7 +1217,7 @@ def get_zones(conn, opts): # Gets the number of items in a partition def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions + num_slaves_this_zone = total // num_partitions if (total % num_partitions) - current_partitions > 0: num_slaves_this_zone += 1 return num_slaves_this_zone diff --git a/examples/pom.xml b/examples/pom.xml index 5b04b4f8d6ca0..e6884b09dca94 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -97,6 +97,11 @@ + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + org.apache.hbase hbase-testing-util @@ -392,45 +397,6 @@ - - - scala-2.10 - - !scala-2.11 - - - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-scala-sources - generate-sources - - add-source - - - - src/main/scala - scala-2.10/src/main/scala - scala-2.10/src/main/java - - - - - - - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index eac4f898a475d..9df26ffca5775 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -28,6 +28,7 @@ import org.apache.spark.ml.classification.ClassificationModel; import 
org.apache.spark.ml.param.IntParam; import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.util.Identifiable$; import org.apache.spark.mllib.linalg.BLAS; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -103,7 +104,23 @@ public static void main(String[] args) throws Exception { * However, this should still compile and run successfully. */ class MyJavaLogisticRegression - extends Classifier { + extends Classifier { + + public MyJavaLogisticRegression() { + init(); + } + + public MyJavaLogisticRegression(String uid) { + this.uid_ = uid; + init(); + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; + } /** * Param for max number of iterations @@ -117,7 +134,7 @@ class MyJavaLogisticRegression int getMaxIter() { return (Integer) getOrDefault(maxIter); } - public MyJavaLogisticRegression() { + private void init() { setMaxIter(100); } @@ -137,7 +154,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(this, weights); + return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); + } + + @Override + public MyJavaLogisticRegression copy(ParamMap extra) { + return defaultCopy(extra); } } @@ -149,17 +171,21 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { * However, this should still compile and run successfully. */ class MyJavaLogisticRegressionModel - extends ClassificationModel { - - private MyJavaLogisticRegression parent_; - public MyJavaLogisticRegression parent() { return parent_; } + extends ClassificationModel { private Vector weights_; public Vector weights() { return weights_; } - public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) { - this.parent_ = parent_; - this.weights_ = weights_; + public MyJavaLogisticRegressionModel(String uid, Vector weights) { + this.uid_ = uid; + this.weights_ = weights; + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; } // This uses the default implementation of transform(), which reads column "features" and outputs @@ -204,6 +230,6 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java new file mode 100644 index 0000000000000..75063dbf800d8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.commons.cli.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.OneVsRest; +import org.apache.spark.ml.classification.OneVsRestModel; +import org.apache.spark.ml.util.MetadataUtils; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + *
    + * bin/run-example ml.JavaOneVsRestExample [options]
    + * 
    + */ +public class JavaOneVsRestExample { + + private static class Params { + String input; + String testInput = null; + Integer maxIter = 100; + double tol = 1E-6; + boolean fitIntercept = true; + Double regParam = null; + Double elasticNetParam = null; + double fracTest = 0.2; + } + + public static void main(String[] args) { + // parse the arguments + Params params = parse(args); + SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // configure the base classifier + LogisticRegression classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept); + + if (params.regParam != null) { + classifier.setRegParam(params.regParam); + } + if (params.elasticNetParam != null) { + classifier.setElasticNetParam(params.elasticNetParam); + } + + // instantiate the One Vs Rest Classifier + OneVsRest ovr = new OneVsRest().setClassifier(classifier); + + String input = params.input; + RDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); + RDD train; + RDD test; + + // compute the train/ test split: if testInput is not provided use part of input + String testInput = params.testInput; + if (testInput != null) { + train = inputData; + // compute the number of features in the training set. + int numFeatures = inputData.first().features().size(); + test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + } else { + double f = params.fracTest; + RDD[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + train = tmp[0]; + test = tmp[1]; + } + + // train the multiclass model + DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); + OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + + // score the model on test data + DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); + DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + .select("prediction", "label"); + + // obtain metrics + MulticlassMetrics metrics = new MulticlassMetrics(predictions); + StructField predictionColSchema = predictions.schema().apply("prediction"); + Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); + + // compute the false positive rate per label + StringBuilder results = new StringBuilder(); + results.append("label\tfpr\n"); + for (int label = 0; label < numClasses; label++) { + results.append(label); + results.append("\t"); + results.append(metrics.falsePositiveRate((double) label)); + results.append("\n"); + } + + Matrix confusionMatrix = metrics.confusionMatrix(); + // output the Confusion Matrix + System.out.println("Confusion Matrix"); + System.out.println(confusionMatrix); + System.out.println(); + System.out.println(results); + + jsc.stop(); + } + + private static Params parse(String[] args) { + Options options = generateCommandlineOptions(); + CommandLineParser parser = new PosixParser(); + Params params = new Params(); + + try { + CommandLine cmd = parser.parse(options, args); + String value; + if (cmd.hasOption("input")) { + params.input = cmd.getOptionValue("input"); + } + if (cmd.hasOption("maxIter")) { + value = cmd.getOptionValue("maxIter"); + params.maxIter = Integer.parseInt(value); + } + if (cmd.hasOption("tol")) { + value = cmd.getOptionValue("tol"); + params.tol = Double.parseDouble(value); + } + if (cmd.hasOption("fitIntercept")) { + value = cmd.getOptionValue("fitIntercept"); + params.fitIntercept = 
Boolean.parseBoolean(value); + } + if (cmd.hasOption("regParam")) { + value = cmd.getOptionValue("regParam"); + params.regParam = Double.parseDouble(value); + } + if (cmd.hasOption("elasticNetParam")) { + value = cmd.getOptionValue("elasticNetParam"); + params.elasticNetParam = Double.parseDouble(value); + } + if (cmd.hasOption("testInput")) { + value = cmd.getOptionValue("testInput"); + params.testInput = value; + } + if (cmd.hasOption("fracTest")) { + value = cmd.getOptionValue("fracTest"); + params.fracTest = Double.parseDouble(value); + } + + } catch (ParseException e) { + printHelpAndQuit(options); + } + return params; + } + + private static Options generateCommandlineOptions() { + Option input = OptionBuilder.withArgName("input") + .hasArg() + .isRequired() + .withDescription("input path to labeled examples. This path must be specified") + .create("input"); + Option testInput = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("input path to test examples") + .create("testInput"); + Option fracTest = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("fraction of data to hold out for testing." + + " If given option testInput, this option is ignored. default: 0.2") + .create("fracTest"); + Option maxIter = OptionBuilder.withArgName("maxIter") + .hasArg() + .withDescription("maximum number of iterations for Logistic Regression. default:100") + .create("maxIter"); + Option tol = OptionBuilder.withArgName("tol") + .hasArg() + .withDescription("the convergence tolerance of iterations " + + "for Logistic Regression. default: 1E-6") + .create("tol"); + Option fitIntercept = OptionBuilder.withArgName("fitIntercept") + .hasArg() + .withDescription("fit intercept for logistic regression. default true") + .create("fitIntercept"); + Option regParam = OptionBuilder.withArgName( "regParam" ) + .hasArg() + .withDescription("the regularization parameter for Logistic Regression.") + .create("regParam"); + Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" ) + .hasArg() + .withDescription("the ElasticNet mixing parameter for Logistic Regression.") + .create("elasticNetParam"); + + Options options = new Options() + .addOption(input) + .addOption(testInput) + .addOption(fracTest) + .addOption(maxIter) + .addOption(tol) + .addOption(fitIntercept) + .addOption(regParam) + .addOption(elasticNetParam); + + return options; + } + + private static void printHelpAndQuit(Options options) { + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp("JavaOneVsRestExample", options); + System.exit(-1); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 29158d5c85651..dac649d1d5ae6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -97,7 +97,7 @@ public static void main(String[] args) { DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. 
DataFrame results = model2.transform(test); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index ef1ec103a879f..54738813d0016 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setOutputCol("features"); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01); + .setRegParam(0.001); Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); @@ -77,7 +77,7 @@ public static void main(String[] args) { List localTest = Lists.newArrayList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), + new Document(6L, "spark hadoop spark"), new Document(7L, "apache hadoop")); DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 8159ffbe2d269..afee279ec32b1 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -94,12 +94,12 @@ public String call(Row row) { System.out.println("=== Data source: Parquet File ==="); // DataFrames can be saved as parquet files, maintaining the schema information. - schemaPeople.saveAsParquetFile("people.parquet"); + schemaPeople.write().parquet("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.read().json(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. 
peopleFromJsonRDD.printSchema(); diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 5b82a14fba413..c5ae5d043b8ea 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -18,6 +18,7 @@ from __future__ import print_function import sys +import json from pyspark import SparkContext @@ -27,24 +28,24 @@ hbase(main):016:0> create 'test', 'f1' 0 row(s) in 1.0430 seconds -hbase(main):017:0> put 'test', 'row1', 'f1', 'value1' +hbase(main):017:0> put 'test', 'row1', 'f1:a', 'value1' 0 row(s) in 0.0130 seconds -hbase(main):018:0> put 'test', 'row2', 'f1', 'value2' +hbase(main):018:0> put 'test', 'row1', 'f1:b', 'value2' 0 row(s) in 0.0030 seconds -hbase(main):019:0> put 'test', 'row3', 'f1', 'value3' +hbase(main):019:0> put 'test', 'row2', 'f1', 'value3' 0 row(s) in 0.0050 seconds -hbase(main):020:0> put 'test', 'row4', 'f1', 'value4' +hbase(main):020:0> put 'test', 'row3', 'f1', 'value4' 0 row(s) in 0.0110 seconds hbase(main):021:0> scan 'test' ROW COLUMN+CELL - row1 column=f1:, timestamp=1401883411986, value=value1 - row2 column=f1:, timestamp=1401883415212, value=value2 - row3 column=f1:, timestamp=1401883417858, value=value3 - row4 column=f1:, timestamp=1401883420805, value=value4 + row1 column=f1:a, timestamp=1401883411986, value=value1 + row1 column=f1:b, timestamp=1401883415212, value=value2 + row2 column=f1:, timestamp=1401883417858, value=value3 + row3 column=f1:, timestamp=1401883420805, value=value4 4 row(s) in 0.0240 seconds """ if __name__ == "__main__": @@ -64,6 +65,8 @@ table = sys.argv[2] sc = SparkContext(appName="HBaseInputFormat") + # Other options for configuring scan behavior are available. 
More information available at + # https://github.com/apache/hbase/blob/master/hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormat.java conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} if len(sys.argv) > 3: conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], @@ -78,6 +81,8 @@ keyConverter=keyConv, valueConverter=valueConv, conf=conf) + hbase_rdd = hbase_rdd.flatMapValues(lambda v: v.split("\n")).mapValues(json.loads) + output = hbase_rdd.collect() for (k, v) in output: print((k, v)) diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 1456c87312841..0ea7cfb7025a0 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -68,7 +68,7 @@ def closestPoint(p, centers): closest = data.map( lambda p: (closestPoint(p, kPoints), (p, 1))) pointStats = closest.reduceByKey( - lambda (p1, c1), (p2, c2): (p1 + p2, c1 + c2)) + lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1])) newPoints = pointStats.map( lambda st: (st[0], st[1][0] / st[1][1])).collect() diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py new file mode 100644 index 0000000000000..f0ca97c724940 --- /dev/null +++ b/examples/src/main/python/ml/cross_validator.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.evaluation import BinaryClassificationEvaluator +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="CrossValidatorExample") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. 
+ tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + # This will allow us to jointly choose parameters for all Pipeline stages. + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + paramGrid = ParamGridBuilder() \ + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .build() + + crossval = CrossValidator(estimator=pipeline, + estimatorParamMaps=paramGrid, + evaluator=BinaryClassificationEvaluator(), + numFolds=2) # use 3+ folds in practice + + # Run cross-validation, and choose the best set of parameters. + cvModel = crossval.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents. cvModel uses the best model found (lrModel). + prediction = cvModel.transform(test) + selected = prediction.select("id", "text", "probability", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py new file mode 100644 index 0000000000000..6446f0fe5eeab --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import GBTRegressor +from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. +Note: GBTClassifier only supports binary classification currently +Run with: + bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py +""" + + +def testClassification(train, test): + # Train a GradientBoostedTrees model. 
+ + rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = BinaryClassificationMetrics(predictionAndLabels) + print("AUC %.3f" % metrics.areaUnderROC) + + +def testRegression(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGBTExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 0000000000000..55afe1b207fe0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. 
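Both tree examples above first run a StringIndexer to turn the raw label column into indices in [0, numLabels). A rough standalone sketch of that mapping (ordering labels by descending frequency, as the indexer does), using made-up labels:

```
from collections import Counter

labels = ["0.0", "1.0", "1.0", "0.0", "1.0"]
ordered = [lab for lab, _ in Counter(labels).most_common()]
index = {lab: float(i) for i, lab in enumerate(ordered)}

print(index)                           # {'1.0': 0.0, '0.0': 1.0}
print([index[lab] for lab in labels])  # [1.0, 0.0, 0.0, 1.0, 0.0]
```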
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py new file mode 100644 index 0000000000000..c7730e1bfacd9 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import RandomForestRegressor +from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a RandomForest Classification/Regression Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/random_forest_example.py +""" + + +def testClassification(train, test): + # Train a RandomForest model. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + # Note: Use larger numTrees in practice. + + rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + +def testRegression(train, test): + # Train a RandomForest model. + # Note: Use larger numTrees in practice. 
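The logistic regression above sets regParam=0.3 and elasticNetParam=0.8, i.e. an elastic-net mix of L1 and L2 regularization. A small sketch of that penalty term under the usual glmnet-style parameterization (the weights here are hypothetical):

```
def elastic_net_penalty(weights, reg_param=0.3, alpha=0.8):
    l1 = sum(abs(w) for w in weights)
    l2 = sum(w * w for w in weights)
    # alpha = 1 gives a pure L1 penalty, alpha = 0 a pure L2 penalty
    return reg_param * (alpha * l1 + (1.0 - alpha) / 2.0 * l2)

print(elastic_net_penalty([0.5, -1.0, 0.25]))
```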
+ + rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: random_forest_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonRandomForestExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py new file mode 100644 index 0000000000000..a9f29dab2d602 --- /dev/null +++ b/examples/src/main/python/ml/simple_params_example.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import pprint +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql import SQLContext + +""" +A simple example demonstrating ways to specify parameters for Estimators and Transformers. +Run with: + bin/spark-submit examples/src/main/python/ml/simple_params_example.py +""" + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: simple_params_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonSimpleParamsExample") + sqlContext = SQLContext(sc) + + # prepare training data. + # We create an RDD of LabeledPoints and convert them into a DataFrame. + # A LabeledPoint is an Object with two fields named label and features + # and Spark SQL identifies these fields and creates the schema appropriately. + training = sc.parallelize([ + LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), + LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), + LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), + LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + + # Create a LogisticRegression instance with maxIter = 10. + # This instance is an Estimator. + lr = LogisticRegression(maxIter=10) + # Print out the parameters, documentation, and any default values. 
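The regression branches above print RMSE, R2 and MAE computed from (prediction, label) pairs. A quick plain-Python sketch of those three metrics on a few invented pairs:

```
import math

pairs = [(2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]
n = len(pairs)
labels = [y for _, y in pairs]
mean_label = sum(labels) / n

sse = sum((p - y) ** 2 for p, y in pairs)          # sum of squared errors
rmse = math.sqrt(sse / n)
mae = sum(abs(p - y) for p, y in pairs) / n
r2 = 1.0 - sse / sum((y - mean_label) ** 2 for y in labels)

print("rmse %.3f r2 %.3f mae %.3f" % (rmse, r2, mae))
```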
+ print("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + # We may also set parameters using setter methods. + lr.setRegParam(0.01) + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a Transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print("Model 1 was fit using parameters:\n") + pprint.pprint(model1.extractParamMap()) + + # We may alternatively specify parameters using a parameter map. + # paramMap overrides all lr parameters set earlier. + paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + + # Now learn a new model using the new parameters. + model2 = lr.fit(training, paramMap) + print("Model 2 was fit using parameters:\n") + pprint.pprint(model2.extractParamMap()) + + # prepare test data. + test = sc.parallelize([ + LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), + LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), + LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegressionModel.transform will only use the 'features' column. + # Note that model2.transform() outputs a 'myProbability' column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + result = model2.transform(test) \ + .select("features", "label", "myProbability", "prediction") \ + .collect() + + for row in result: + print("features=%s,label=%s -> prob=%s, prediction=%s" + % (row.features, row.label, row.myProbability, row.prediction)) + + sc.stop() diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index fab21f003b233..b4f06bf888746 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -48,7 +48,7 @@ # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") - lr = LogisticRegression(maxIter=10, regParam=0.01) + lr = LogisticRegression(maxIter=10, regParam=0.001) pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. 
@@ -58,7 +58,7 @@ Document = Row("id", "text") test = sc.parallelize([(4, "spark i j k"), (5, "l m n"), - (6, "mapreduce spark"), + (6, "spark hadoop spark"), (7, "apache hadoop")]) \ .map(lambda x: Document(*x)).toDF() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 96ddac761d698..e1fd85b082c08 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -51,7 +51,7 @@ parquet_rdd = sc.newAPIHadoopFile( path, - 'parquet.avro.AvroParquetInputFormat', + 'org.apache.parquet.avro.AvroParquetInputFormat', 'java.lang.Void', 'org.apache.avro.generic.IndexedRecord', valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 0000000000000..091b64d8c4af4 --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ + spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py new file mode 100644 index 0000000000000..dcd6a0fc6ff91 --- /dev/null +++ b/examples/src/main/python/streaming/queue_stream.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
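flume_wordcount.py above applies the usual flatMap / map / reduceByKey word count to each batch of lines received from Flume. The same per-batch logic in plain Python, on a couple of hypothetical lines:

```
from collections import defaultdict

lines = ["spark streaming with flume", "spark again"]

counts = defaultdict(int)
for line in lines:                # flatMap: split each line into words
    for word in line.split(" "):
        counts[word] += 1         # map to (word, 1), then reduceByKey sums per word

print(dict(counts))
```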
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Create a queue of RDDs that will be mapped/reduced one at a time in + 1 second intervals. + + To run this example use + `$ bin/spark-submit examples/src/main/python/streaming/queue_stream.py +""" +import sys +import time + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonStreamingQueueStream") + ssc = StreamingContext(sc, 1) + + # Create the queue through which RDDs can be pushed to + # a QueueInputDStream + rddQueue = [] + for i in xrange(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + + # Create the QueueInputDStream and use it do some processing + inputStream = ssc.queueStream(rddQueue) + mappedStream = inputStream.map(lambda x: (x % 10, 1)) + reducedStream = mappedStream.reduceByKey(lambda a, b: a + b) + reducedStream.pprint() + + ssc.start() + time.sleep(6) + ssc.stop(stopSparkContext=True, stopGraceFully=True) diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R new file mode 100644 index 0000000000000..aa2336e300a91 --- /dev/null +++ b/examples/src/main/r/data-manipulation.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns. 
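Each RDD queued in queue_stream.py above holds the numbers 1..1000, and the stream maps every number to (x % 10, 1) before summing per key, so every batch reports 100 hits for each final digit. A plain-Python sketch of one batch:

```
from collections import Counter

batch = range(1, 1001)
counts = Counter(x % 10 for x in batch)
print(sorted(counts.items()))  # each residue 0..9 occurs exactly 100 times
```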
+ +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 4c129dbe2d12d..d812262fd87dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -52,3 +53,4 @@ object BroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 11d5c92c5952d..36832f51d2ad4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ + // scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -104,8 +105,8 @@ object CassandraCQLTest { val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[CqlPagingInputFormat], - classOf[java.util.Map[String,ByteBuffer]], - classOf[java.util.Map[String,ByteBuffer]]) + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.Map[String, ByteBuffer]]) println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { @@ -118,7 +119,7 @@ object CassandraCQLTest { case (productId, saleCount) => println(productId + ":" + saleCount) } - val casoutputCF = aggregatedRDD.map { + val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) val outKey: java.util.Map[String, ByteBuffer] = outColFamKey @@ -140,3 +141,4 @@ object CassandraCQLTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ec689474aecb0..96ef3e198e380 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { sc.stop() } } +// scalastyle:on println /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala new file mode 100644 index 0000000000000..d651fe4d6ee75 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
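The R example above groups the flights data frame by date and averages the departure and arrival delays into a dailyDelayDF summary. A rough Python sketch of that group-by aggregation over a few invented rows:

```
from collections import defaultdict

flights = [
    {"date": "2011-01-01", "dep_delay": 5, "arr_delay": 10},
    {"date": "2011-01-01", "dep_delay": 15, "arr_delay": 0},
    {"date": "2011-01-02", "dep_delay": 0, "arr_delay": -5},
]

by_date = defaultdict(list)
for f in flights:
    by_date[f["date"]].append(f)

daily_delay = {
    d: (sum(f["dep_delay"] for f in fs) / len(fs),
        sum(f["arr_delay"] for f in fs) / len(fs))
    for d, fs in by_date.items()
}
print(daily_delay)  # {'2011-01-01': (10.0, 5.0), '2011-01-02': (0.0, -5.0)}
```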
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.io.File + +import scala.io.Source._ + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.SparkContext._ + +/** + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. Compares the word count results + */ +object DFSReadWriteTest { + + private var localFilePath: File = new File(".") + private var dfsDirPath: String = "" + + private val NPARAMS = 2 + + private def readFile(filename: String): List[String] = { + val lineIter: Iterator[String] = fromFile(filename).getLines() + val lineList: List[String] = lineIter.toList + lineList + } + + private def printUsage(): Unit = { + val usage: String = "DFS Read-Write Test\n" + + "\n" + + "Usage: localFile dfsDir\n" + + "\n" + + "localFile - (string) local file to use in test\n" + + "dfsDir - (string) DFS directory for read/write tests\n" + + println(usage) + } + + private def parseArgs(args: Array[String]): Unit = { + if (args.length != NPARAMS) { + printUsage() + System.exit(1) + } + + var i = 0 + + localFilePath = new File(args(i)) + if (!localFilePath.exists) { + System.err.println("Given path (" + args(i) + ") does not exist.\n") + printUsage() + System.exit(1) + } + + if (!localFilePath.isFile) { + System.err.println("Given path (" + args(i) + ") is not a file.\n") + printUsage() + System.exit(1) + } + + i += 1 + dfsDirPath = args(i) + } + + def runLocalWordCount(fileContents: List[String]): Int = { + fileContents.flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .groupBy(w => w) + .mapValues(_.size) + .values + .sum + } + + def main(args: Array[String]): Unit = { + parseArgs(args) + + println("Performing local word count") + val fileContents = readFile(localFilePath.toString()) + val localWordCount = runLocalWordCount(fileContents) + + println("Creating SparkConf") + val conf = new SparkConf().setAppName("DFS Read Write Test") + + println("Creating SparkContext") + val sc = new SparkContext(conf) + + println("Writing local file to DFS") + val dfsFilename = dfsDirPath + "/dfs_read_write_test" + val fileRDD = sc.parallelize(fileContents) + fileRDD.saveAsTextFile(dfsFilename) + + println("Reading file from DFS and running Word Count") + val readFileRDD = sc.textFile(dfsFilename) + + val dfsWordCount = readFileRDD + .flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .map(w => (w, 1)) + .countByKey() + .values + .sum + + sc.stop() + + if (localWordCount == dfsWordCount) { + println(s"Success! 
Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) agree.") + } else { + println(s"Failure! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) disagree.") + } + + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index e757283823fc3..c42df2b8845d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.collection.JavaConversions._ @@ -46,3 +47,4 @@ object DriverSubmissionTest { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 15f6678648b29..fa4a3afeecd19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -53,3 +54,4 @@ object GroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 849887d23c9cf..244742327a907 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin @@ -59,5 +60,7 @@ object HBaseTest { hBaseRDD.count() sc.stop() + admin.close() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index ed2b38e2ca6f8..124dc9af6390f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark._ @@ -41,3 +42,4 @@ object HdfsTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3d5259463003d..af5f216f28ba4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -15,6 +15,7 @@ * limitations under the License. 
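DFSReadWriteTest above compares a purely local word count with the same count computed through Spark after a DFS round trip. A plain-Python sketch of the local half (split on spaces and tabs, drop empty tokens, count), on hypothetical file contents:

```
file_contents = ["hello world", "hello\tdfs test", ""]

words = [
    w
    for line in file_contents
    for chunk in line.split(" ")
    for w in chunk.split("\t")
    if len(w) > 0
]
print(len(words))  # 5 -- same total as the groupBy / mapValues(_.size) / sum approach
```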
*/ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -142,3 +143,4 @@ object LocalALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index ac2ea35bbd0e0..9c8aae53cf48d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -73,3 +74,4 @@ object LocalFileLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 04fc0a033014a..e7b28d38bdfc6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -119,3 +120,4 @@ object LocalKMeans { println("Final centers: " + kPoints) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index a55e0dc8d36c2..4f6b092a59ca5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -39,7 +40,7 @@ object LocalLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } @@ -77,3 +78,4 @@ object LocalLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index ee6b3ee34aeb2..3d923625f11b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -33,3 +34,4 @@ object LocalPi { println("Pi is roughly " + 4 * count / 100000.0) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 32e02eab8b031..a80de10f4610a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -22,7 +23,7 @@ import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. 
- * + * * Usage: LogQuery [logFile] */ object LogQuery { @@ -83,3 +84,4 @@ object LogQuery { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 2a5c0c0defe13..61ce9db914f9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.rdd.RDD @@ -53,3 +54,4 @@ object MultiBroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 5291ab81f459e..3b0b00fe4dd0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -67,3 +68,4 @@ object SimpleSkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 017d4e1e5ce13..719e2176fed3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -57,3 +58,4 @@ object SkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 6c0ac8013ce34..69799b7c2bb30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -117,7 +118,7 @@ object SparkALS { var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users - val Rc = sc.broadcast(R) + val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { @@ -144,3 +145,4 @@ object SparkALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9099c2fcc90b3..505ea5a4c7a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -97,3 +98,4 @@ object SparkHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index b514d9123f5e7..c56e1124ad415 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import breeze.linalg.{Vector, DenseVector, squaredDistance} @@ -100,3 +101,4 @@ object SparkKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 8c01a60844620..d265c227f4ed2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -44,7 +45,7 @@ object SparkLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } @@ -86,3 +87,4 @@ object SparkLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 8d092b6506d33..0fd79660dd196 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.SparkContext._ @@ -51,7 +52,7 @@ object SparkPageRank { showWarning() val sparkConf = new SparkConf().setAppName("PageRank") - val iters = if (args.length > 0) args(1).toInt else 10 + val iters = if (args.length > 1) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => @@ -74,3 +75,4 @@ object SparkPageRank { ctx.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 35b8dd6c29b66..818d4f2b81f82 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -37,3 +38,4 @@ object SparkPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 772cd897f5140..95072071ccddb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -15,6 +15,7 @@ * limitations under the License. 
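Among the changes above, SparkPageRank's guard for the optional iteration count is corrected from args.length > 0 to args.length > 1, since the count is the second argument. A plain-Python analogy of the fixed check, with a hypothetical argument list:

```
args = ["links.txt"]  # only the input path is supplied

# Fixed guard: read args[1] only when a second argument actually exists.
iters = int(args[1]) if len(args) > 1 else 10
print(iters)  # 10; the old `len(args) > 0` guard would raise IndexError here
```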
*/ +// scalastyle:off println package org.apache.spark.examples import scala.util.Random @@ -70,3 +71,4 @@ object SparkTC { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 4393b99e636b6..cfbdae02212a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -94,3 +95,4 @@ object SparkTachyonHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 7743f7968b100..e46ac655beb58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -46,3 +47,4 @@ object SparkTachyonPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index ab6e63deb3c95..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.bagel._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) { - self.outEdges.map(targetId => new PRMessage(targetId, newValue / self.outEdges.size)) - } else { - Array[PRMessage]() - } - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double) - (self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int) - : (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)" - .format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions: Int = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } - - override def hashCode: Int = numPartitions -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 859abedf2a55e..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ - -import org.apache.spark.bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println( - "Usage: WikipediaPageRank ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.setAppName("WikipediaPageRank") - sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage])) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRank") - val sc = new SparkContext(sparkConf) - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) { - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache() - } else { - vertices = vertices.cache() - } - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect().mkString) - println(top) - - sc.stop() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index 576a3e371b993..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.xml.{XML, NodeSeq} - -import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import scala.reflect.ClassTag - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRankStandalone " + - " ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRankStandalone") - - val sc = new SparkContext(sparkConf) - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) { - input.map(parseArticle _).partitionBy(partitioner).cache() - } else { - input.map(parseArticle _).cache() - } - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, - sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - sc.stop() - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to 
numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapperIterable, rankWrapperIterable)) => - val linksWrapper = linksWrapperIterable.iterator - val rankWrapper = rankWrapperIterable.iterator - if (linksWrapper.hasNext) { - val linksWrapperHead = linksWrapper.next - if (rankWrapper.hasNext) { - val rankWrapperHead = rankWrapper.next - linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size)) - } else { - linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends org.apache.spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T: ClassTag](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T: ClassTag](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 409721b01c8fd..8dd6c9706e7df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,6 +15,7 @@ * limitations under the License. 
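The deleted Bagel examples above implement PageRank as repeated updates of the form a/n + (1 - a) * (sum of incoming contributions), where each page sends rank/outDegree to the pages it links to. A compact standalone sketch of that iteration on a tiny hypothetical graph:

```
links = {"a": ["b", "c"], "b": ["c"], "c": ["a"]}
n = len(links)
a = 0.15
ranks = {p: 1.0 / n for p in links}

for _ in range(20):
    contribs = {p: 0.0 for p in links}
    for page, outs in links.items():
        for dest in outs:
            contribs[dest] += ranks[page] / len(outs)   # rank spread over out-links
    ranks = {p: a / n + (1 - a) * c for p, c in contribs.items()}

print({p: round(r, 3) for p, r in ranks.items()})
```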
*/ +// scalastyle:off println package org.apache.spark.examples.graphx import scala.collection.mutable @@ -151,3 +152,4 @@ object Analytics extends Logging { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index f6f8d9f90c275..da3ffca1a6f2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx /** @@ -42,3 +43,4 @@ object LiveJournalPageRank { Analytics.main(args.patch(0, List("pagerank"), 0)) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 3ec20d594b784..46e52aacd90bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ @@ -128,3 +129,4 @@ object SynthBenchmark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 6c0af20461d3b..14b358d46f6ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -110,3 +111,4 @@ object CrossValidatorExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 54e4073941056..f28671f7869fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -355,3 +356,4 @@ object DecisionTreeExample { println(s" Root mean squared error (RMSE): $RMSE") } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 2a2d0677272a0..78f31b4ffe56a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -15,11 +15,13 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams} import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -106,10 +108,12 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ -private class MyLogisticRegression +private class MyLogisticRegression(override val uid: String) extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] with MyLogisticRegressionParams { + def this() = this(Identifiable.randomUID("myLogReg")) + setMaxIter(100) // Initialize // The parameter setter is in this class since it should return type MyLogisticRegression. @@ -125,8 +129,10 @@ private class MyLogisticRegression val weights = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(this, weights) + new MyLogisticRegressionModel(uid, weights).setParent(this) } + + override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) } /** @@ -135,7 +141,7 @@ private class MyLogisticRegression * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ private class MyLogisticRegressionModel( - override val parent: MyLogisticRegression, + override val uid: String, val weights: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -173,6 +179,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(parent, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 33905277c7341..f4a15f806ea81 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -236,3 +237,4 @@ object GBTExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala new file mode 100644 index 0000000000000..b73299fb12d3f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
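Editor's aside on the DeveloperApiExample hunk above: Spark 1.4's ML API identifies estimators and models by an immutable `uid` string (the parent estimator is attached afterwards via `setParent`, and `copy` falls back to `defaultCopy`). A REPL-style sketch of just the uid idiom, with a made-up class name:

```scala
import org.apache.spark.ml.util.Identifiable

// Hypothetical class, shown only for the constructor idiom:
// the primary constructor takes the uid, the auxiliary one generates it.
class MyThing(val uid: String) {
  def this() = this(Identifiable.randomUID("myThing"))
}

val thing = new MyThing()
println(thing.uid)  // "myThing_" followed by a short random suffix
```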
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.sql.DataFrame + +/** + * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LinearRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \ + * data/mllib/sample_linear_regression_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LinearRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegressionExample") { + head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LinearRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "regression", params.fracTest) + + val lir = new LinearRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + // Train the model + val startTime = System.nanoTime() + val lirModel = lir.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Print the weights and intercept for linear regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala new file mode 100644 index 0000000000000..7682557127b51 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.DataFrame + +/** + * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LogisticRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \ + * data/mllib/sample_libsvm_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LogisticRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + fitIntercept: Boolean = true, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LogisticRegressionExample") { + head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Boolean]("fitIntercept") + .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LogisticRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "classification", params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol("indexedLabel") + stages += labelIndexer + + val lor = new LogisticRegression() + .setFeaturesCol("features") + .setLabelCol("indexedLabel") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + stages += lor + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + // Print the weights and intercept for logistic regression. + println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 25f21113bf622..cd411397a4b9d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scopt.OptionParser @@ -178,3 +179,4 @@ object MovieLensALS { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala new file mode 100644 index 0000000000000..bab31f585b0ef --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} + +import scopt.OptionParser + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + * {{{ + * ./bin/run-example ml.OneVsRestExample [options] + * }}} + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object OneVsRestExample { + + case class Params private[ml] ( + input: String = null, + testInput: Option[String] = None, + maxIter: Int = 100, + tol: Double = 1E-6, + fitIntercept: Boolean = true, + regParam: Option[Double] = None, + elasticNetParam: Option[Double] = None, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("OneVsRest Example") { + head("OneVsRest Example: multiclass to binary reduction using OneVsRest") + opt[String]("input") + .text("input path to labeled examples. This path must be specified") + .required() + .action((x, c) => c.copy(input = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text("input path to test dataset. If given, option fracTest is ignored") + .action((x, c) => c.copy(testInput = Some(x))) + opt[Int]("maxIter") + .text(s"maximum number of iterations for Logistic Regression." + + s" default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations for Logistic Regression." + + s" default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Boolean]("fitIntercept") + .text(s"fit intercept for Logistic Regression." 
+ + s" default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("regParam") + .text(s"the regularization parameter for Logistic Regression.") + .action((x, c) => c.copy(regParam = Some(x))) + opt[Double]("elasticNetParam") + .text(s"the ElasticNet mixing parameter for Logistic Regression.") + .action((x, c) => c.copy(elasticNetParam = Some(x))) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + private def run(params: Params) { + val conf = new SparkConf().setAppName(s"OneVsRestExample with $params") + val sc = new SparkContext(conf) + val inputData = MLUtils.loadLibSVMFile(sc, params.input) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // compute the train/test split: if testInput is not provided use part of input. + val data = params.testInput match { + case Some(t) => { + // compute the number of features in the training set. + val numFeatures = inputData.first().features.size + val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures) + Array[RDD[LabeledPoint]](inputData, testData) + } + case None => { + val f = params.fracTest + inputData.randomSplit(Array(1 - f, f), seed = 12345) + } + } + val Array(train, test) = data.map(_.toDF().cache()) + + // instantiate the base classifier + val classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept) + + // Set regParam, elasticNetParam if specified in params + params.regParam.foreach(classifier.setRegParam) + params.elasticNetParam.foreach(classifier.setElasticNetParam) + + // instantiate the One Vs Rest Classifier. + + val ovr = new OneVsRest() + ovr.setClassifier(classifier) + + // train the multiclass model. + val (trainingDuration, ovrModel) = time(ovr.fit(train)) + + // score the model on test data. 
+ val (predictionDuration, predictions) = time(ovrModel.transform(test)) + + // evaluate the model + val predictionsAndLabels = predictions.select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + + val metrics = new MulticlassMetrics(predictionsAndLabels) + + val confusionMatrix = metrics.confusionMatrix + + // compute the false positive rate per label + val predictionColSchema = predictions.schema("prediction") + val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get + val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) + + println(s" Training Time ${trainingDuration} sec\n") + + println(s" Prediction Time ${predictionDuration} sec\n") + + println(s" Confusion Matrix\n ${confusionMatrix.toString}\n") + + println("label\tfpr") + + println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + + sc.stop() + } + + private def time[R](block: => R): (Long, R) = { + val t0 = System.nanoTime() + val result = block // call-by-name + val t1 = System.nanoTime() + (NANO.toSeconds(t1 - t0), result) + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 9f7cad68a4594..109178f4137b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -244,3 +245,4 @@ object RandomForestExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a991f50e338..58d7b67674ff7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -87,7 +88,7 @@ object SimpleParamsExample { LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. model2.transform(test.toDF()) @@ -100,3 +101,4 @@ object SimpleParamsExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 6772efd2c581c..960280137cbf9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -15,6 +15,7 @@ * limitations under the License. 
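Editor's aside on the timing helper at the end of OneVsRestExample above: it relies on a call-by-name parameter, which is why one helper can bracket both training and prediction. A REPL-style restatement with a trivial, made-up workload:

```scala
import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO}

def time[R](block: => R): (Long, R) = {
  val t0 = System.nanoTime()
  val result = block  // the by-name block runs here, inside the timer
  val t1 = System.nanoTime()
  (NANO.toSeconds(t1 - t0), result)
}

val (seconds, total) = time { (1 to 1000).sum }  // e.g. (0, 500500)
```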
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.beans.BeanInfo @@ -64,7 +65,7 @@ object SimpleTextClassificationPipeline { .setOutputCol("features") val lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01) + .setRegParam(0.001) val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) @@ -75,7 +76,7 @@ object SimpleTextClassificationPipeline { val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), - Document(6L, "mapreduce spark"), + Document(6L, "spark hadoop spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. @@ -89,3 +90,4 @@ object SimpleTextClassificationPipeline { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a113653810b93..1a4016f76c2ad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -153,3 +154,4 @@ object BinaryClassification { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index e49129c4e7844..026d4ecc6d10a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -91,3 +92,4 @@ object Correlations { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index cb1abbd18fd4d..69988cc1b9334 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -106,3 +107,4 @@ object CosineSimilarity { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index e943d6c889fab..dc13f82488af7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import java.io.File @@ -103,10 +104,10 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - df.saveAsParquetFile(outputDir) + df.write.parquet(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.parquetFile(outputDir) + val newDataset = sqlContext.read.parquet(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } @@ -119,3 +120,4 @@ object DatasetExample { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b0613632c9946..57ffe3dd2524f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.language.reflectiveCalls @@ -22,7 +23,6 @@ import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -354,7 +354,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. + * + * This is just for demo purpose. In general, don't copy this code because it is NOT efficient + * due to the use of structural types, which leads to one reflection call per record. */ + // scalastyle:off structural.type private[mllib] def meanSquaredError( model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { @@ -363,4 +367,6 @@ object DecisionTreeRunner { err * err }.mean() } + // scalastyle:on structural.type } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index df76b45e50810..1fce4ba7efd60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -15,6 +15,7 @@ * limitations under the License. 
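Editor's aside on the new warning in DecisionTreeRunner above ("one reflection call per record"): it refers to the structural type of the `model` parameter. A REPL-style sketch, with made-up names, of what such a parameter looks like and why each call is reflective:

```scala
import scala.language.reflectiveCalls

// Any object exposing predict(Double): Double is accepted, but since there is
// no common trait, Scala dispatches every call through reflection.
def firstPrediction(model: { def predict(x: Double): Double }, xs: Seq[Double]): Double =
  model.predict(xs.head)

object ConstantModel {
  def predict(x: Double): Double = 42.0
}

firstPrediction(ConstantModel, Seq(1.0, 2.0))  // returns 42.0 via a reflective call
```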
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -40,23 +41,23 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) - + val ctx = new SparkContext(conf) + val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) .run(data) - + for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } - + println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => @@ -65,3 +66,4 @@ object DenseGaussianMixture { println() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 14cc5cbb679c5..380d85d60e7b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -107,3 +108,4 @@ object DenseKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 13f24a1e59610..14b930550d554 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -80,3 +81,4 @@ object FPGrowthExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 7416fb5a40848..e16a6bf033574 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -145,3 +146,4 @@ object GradientBoostedTreesRunner { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 31d629f853161..75b0f69cf91aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import java.text.BreakIterator @@ -302,3 +303,4 @@ private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Se } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 6a456ba7ec07b..8878061a0970b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -134,3 +135,4 @@ object LinearRegression { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 99588b0984ab2..e43a6f2864c73 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.collection.mutable @@ -189,3 +190,4 @@ object MovieLensALS { math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 6e4e2d07f284b..5f839c75dd581 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -97,3 +98,4 @@ object MultivariateSummarizer { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 6d8b806569dfd..0723223954610 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -154,4 +155,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } - +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 924b586e3af99..bee85ba0f9969 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.random.RandomRDDs @@ -58,3 +59,4 @@ object RandomRDDGeneration { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 663c12734af68..6963f43e082c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.util.MLUtils @@ -125,3 +126,4 @@ object SampledRDDs { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5e..f81fc292a3bd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -100,3 +101,4 @@ object SparseNaiveBayes { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 8bb12d2ee9ed2..af03724a8ac62 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.SparkConf @@ -75,3 +76,4 @@ object StreamingKMeansExample { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1a95048bbfe2d..b4a5dca031abd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -69,3 +70,4 @@ object StreamingLinearRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index e1998099c2d78..b42f4cb5f9338 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -71,3 +72,4 @@ object StreamingLogisticRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 3cd9cb743e309..464fbd385ab5d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnyPCA { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 4d6690318615a..65b4bc46f0266 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnySVD { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index a11890d6f2b1c..3ebb112fc069e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -36,22 +36,21 @@ object AvroConversionUtil extends Serializable { return null } schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException(s"Unknown Avro schema type ${other.getName}") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 273bee0a8b30f..90d48a64106c7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -18,20 +18,34 @@ package org.apache.spark.examples.pythonconverters import scala.collection.JavaConversions._ +import scala.util.parsing.json.JSONObject import 
org.apache.spark.api.python.Converter import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.KeyValue.Type +import org.apache.hadoop.hbase.CellUtil /** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts an - * HBase Result to a String + * Implementation of [[org.apache.spark.api.python.Converter]] that converts all + * the records in an HBase Result to a String */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { + import collection.JavaConverters._ val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) + val output = result.listCells.asScala.map(cell => + Map( + "row" -> Bytes.toStringBinary(CellUtil.cloneRow(cell)), + "columnFamily" -> Bytes.toStringBinary(CellUtil.cloneFamily(cell)), + "qualifier" -> Bytes.toStringBinary(CellUtil.cloneQualifier(cell)), + "timestamp" -> cell.getTimestamp.toString, + "type" -> Type.codeToType(cell.getTypeByte).toString, + "value" -> Bytes.toStringBinary(CellUtil.cloneValue(cell)) + ) + ) + output.map(JSONObject(_).toString()).mkString("\n") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 6331d1c0060f8..2cc56f04e5c1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} @@ -58,10 +59,10 @@ object RDDRelation { df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - df.saveAsParquetFile("pair.parquet") + df.write.parquet("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. - val parquetFile = sqlContext.parquetFile("pair.parquet") + val parquetFile = sqlContext.read.parquet("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) @@ -73,3 +74,4 @@ object RDDRelation { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b7ba60ec28155..bf40bd1ef13df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql.hive import com.google.common.io.{ByteStreams, Files} @@ -77,3 +78,4 @@ object HiveFromSpark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 92867b44be138..e9c9907198769 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
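Editor's aside on the DatasetExample and RDDRelation hunks above: both swap the deprecated `saveAsParquetFile`/`parquetFile` calls for the DataFrame reader/writer API introduced in Spark 1.4. A small sketch of the round trip under that API; the data and the `pair.parquet` path are illustrative only.

```scala
import org.apache.spark.sql.SQLContext

def parquetRoundTrip(sqlContext: SQLContext): Unit = {
  import sqlContext.implicits._
  val df = sqlContext.sparkContext.parallelize(Seq((1, "a"), (2, "b"))).toDF("key", "value")
  df.write.parquet("pair.parquet")                         // replaces df.saveAsParquetFile(...)
  val reloaded = sqlContext.read.parquet("pair.parquet")   // replaces sqlContext.parquetFile(...)
  reloaded.show()
}
```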
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import scala.collection.mutable.LinkedList @@ -104,10 +105,8 @@ extends Actor with ActorHelper { object FeederActor { def main(args: Array[String]) { - if(args.length < 2){ - System.err.println( - "Usage: FeederActor \n" - ) + if (args.length < 2){ + System.err.println("Usage: FeederActor \n") System.exit(1) } val Seq(host, port) = args.toSeq @@ -172,3 +171,4 @@ object ActorWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 30269a7ccae97..28e9bf520e568 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.{InputStreamReader, BufferedReader, InputStream} @@ -100,3 +101,4 @@ class CustomReceiver(host: String, port: Int) } } } +// scalastyle:on println diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala similarity index 95% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 11a8cf09533ce..bd78526f8c299 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import kafka.serializer.StringDecoder @@ -51,7 +52,7 @@ object DirectKafkaWordCount { // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet @@ -70,3 +71,4 @@ object DirectKafkaWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 20e7df7c45b1b..91e52e4eff5a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -66,3 +67,4 @@ object FlumeEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 1cc8c8d5c23b6..2bdbc37e2a289 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -65,3 +66,4 @@ object FlumePollingEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 4b4667fec44e6..1f282d437dc38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -53,3 +54,4 @@ object HdfsWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 94% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index f407367a54f6c..b40d17e9c2fa3 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.util.HashMap @@ -49,10 +50,10 @@ object KafkaWordCount { val Array(zkQuorum, group, topics, numThreads) = args val sparkConf = new SparkConf().setAppName("KafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") - val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap + val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1L)) @@ -96,8 +97,9 @@ object KafkaWordCountProducer { producer.send(message) } - Thread.sleep(100) + Thread.sleep(1000) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 85b9a54b40baf..d772ae309f40d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.eclipse.paho.client.mqttv3._ @@ -40,7 +41,7 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq - + var client: MqttClient = null try { @@ -49,7 +50,7 @@ object MQTTPublisher { client.connect() - val msgtopic = client.getTopic(topic) + val msgtopic = client.getTopic(topic) val msgContent = "hello mqtt demo for spark streaming" val message = new MqttMessage(msgContent.getBytes("utf-8")) @@ -59,10 +60,10 @@ object MQTTPublisher { println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) + Thread.sleep(10) println("Queue is full, wait for to consume data from the message queue") - } - } + } + } } catch { case e: MqttException => println("Exception Caught: " + e) } finally { @@ -96,8 +97,10 @@ object MQTTWordCount { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println System.err.println( "Usage: MQTTWordCount ") + // scalastyle:on println System.exit(1) } @@ -107,9 +110,10 @@ object MQTTWordCount { val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - + wordCounts.print() ssc.start() ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 2cd8073dada14..9a57fe286d1ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -57,3 +58,4 @@ object NetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a9aaa445bccb6..5322929d177b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -58,3 +59,4 @@ object RawNetworkGrep { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 751b30ea15782..9916882e4f94a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.File @@ -108,3 +109,4 @@ object RecoverableNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 5a6b9216a3fbc..ed617754cbf1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -99,3 +100,4 @@ object SQLContextSingleton { instance } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 345d0bc441351..02ba1c2eed0f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -78,3 +79,4 @@ object StatefulNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index c10de84a80ffe..825c671a929b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird._ @@ -113,3 +114,4 @@ object TwitterAlgebirdCMS { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 62db5e663b8af..49826ede70418 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird.HyperLogLogMonoid @@ -90,3 +91,4 @@ object TwitterAlgebirdHLL { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f253d75b279f7..49cee1b43c2dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -82,3 +83,4 @@ object TwitterPopularTags { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index e99d1baa72b9f..6ac9a72c37941 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import akka.actor.ActorSystem @@ -97,3 +98,4 @@ object ZeroMQWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 54d996b8ac990..bea7a47cb2855 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import java.net.ServerSocket @@ -57,8 +58,7 @@ object PageViewGenerator { 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - + val userID = Map((1 to 100).map(_ -> .01) : _*) def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { val rand = new Random().nextDouble() @@ -109,3 +109,4 @@ object PageViewGenerator { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index fbacaee98690f..ec7d39da8b2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -15,6 +15,7 @@ * limitations under the License. 
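The PageViewGenerator change above is purely stylistic: the sequence-to-varargs expansion is now written with a space before the colon, as ": _*". The idiom itself turns a collection of pairs into a Map by expanding it as the varargs of Map.apply, for example:

// Build a uniform distribution over user ids 1..100 as a Map[Int, Double]
val pairs = (1 to 100).map(id => id -> 0.01)  // Seq[(Int, Double)]
val userID = Map(pairs: _*)                   // expand the Seq as varargs to Map.apply
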
*/ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ @@ -107,3 +108,4 @@ object PageViewStream { ssc.start() } } +// scalastyle:on println diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 0000000000000..13189595d1d6c --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,158 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + provided + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-codec + commons-codec + provided + + + commons-net + commons-net + provided + + + com.google.protobuf + protobuf-java + provided + + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + + + org.scala-lang + scala-library + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + + + flume-provided + + provided + + + + + diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a24..0664cfb2021e1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -35,22 +35,49 @@ http://spark.apache.org/ - - org.apache.commons - commons-lang3 - org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index 17cbc6707b5ea..d87b86932dd41 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -113,7 +113,9 @@ private[sink] object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. 
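The reflective lookup in the sink's Logging helper (continued just below) is wrapped in classforname suppression markers, presumably because flume-sink does not depend on spark-core and so cannot route through the Utils.classForName wrapper that the scalastyle classforname rule normally points contributors to. A sketch of the bracketed call as it appears in the sink (the bridge class name is the real slf4j-to-jul bridge; the surrounding code is abridged):

// scalastyle:off classforname
val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler")
// scalastyle:on classforname
// Detach the JUL bridge handlers so sink logging is not re-routed unexpectedly.
bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
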
+ // scalastyle:off classforname val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + // scalastyle:on classforname bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index fd01807fc3ac4..719fca0938b3a 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,14 +16,13 @@ */ package org.apache.spark.streaming.flume.sink +import java.util.UUID import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.flume.Channel -import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process @@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. @@ -55,7 +53,7 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha // Since the new txn may not have the same sequence number we must guard against accidentally // committing a new transaction. To reduce the probability of that happening a random string is // prepended to the sequence number. Does not change for life of sink - private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqBase = UUID.randomUUID().toString.substring(0, 8) private val seqCounter = new AtomicLong(0) // Protected by `sequenceNumberToProcessor` diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala new file mode 100644 index 0000000000000..845fc8debda75 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
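The two substitutions above remove flume-sink's remaining uses of Guava and commons-lang3: the daemon-thread factory is now the hand-rolled SparkSinkThreadFactory defined in the new file that follows, and the random sequence-number prefix comes from java.util.UUID instead of RandomStringUtils. A small sketch of the equivalent JDK-only calls (the pool size is illustrative):

import java.util.UUID
import java.util.concurrent.Executors

// Random 8-character prefix without commons-lang3 (hex characters from a UUID).
val seqBase = UUID.randomUUID().toString.substring(0, 8)

// Named daemon threads without Guava's ThreadFactoryBuilder.
val pool = Executors.newFixedThreadPool(
  4, new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))
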
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong + +/** + * Thread factory that generates daemon threads with a specified name format. + */ +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t + } + +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index ea45b14294df9..7ad43b1d7b0a0 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -143,7 +143,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg(msg) } else { // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("",seqNum, events) + eventBatch = new EventBatch("", seqNum, events) } }) } catch { diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c142..fa43629d49771 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -24,16 +24,24 @@ import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +// Due to MNG-1378, there is not a way to include test dependencies transitively. +// We cannot include Spark core tests as a dependency here because it depends on +// Spark core main, which has too many dependencies to require here manually. +// For this reason, we continue to use FunSuite and ignore the scalastyle checks +// that fail if this is detected. +//scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { +//scalastyle:on + val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). 
- setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 8df7edbdcad33..14f7daaf417e0 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming-flume-sink_${scala.binary.version} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index dc629df4f4ac2..65c49c131518b 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k,v) <- headers) { + for ((k, v) <- headers) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 60e2994431b38..1e32a365a1eee 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -152,9 +152,9 @@ class FlumeReceiver( val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()) val channelPipelineFactory = new CompressionChannelPipelineFactory() - + new NettyServer( - responder, + responder, new InetSocketAddress(host, port), channelFactory, channelPipelineFactory, @@ -188,12 +188,12 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. 
* * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel + * a successful response to indicate it can remove the event/batch + * from the configured channel */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 92fa5b41be89e..583e7dca317ad 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver( } /** - * A wrapper around the transceiver and the Avro IPC API. + * A wrapper around the transceiver and the Avro IPC API. * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 0000000000000..9d9c3b189415f --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
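The FlumeTestUtils class that follows is the shared helper behind both the Scala and Python Flume stream tests: it picks a free port, pushes AvroFlumeEvents at a receiver through an Avro NettyTransceiver, and can optionally wrap the channel in Zlib compression. Its port probing goes through Spark's Utils.startServiceOnPort retry loop; a more self-contained sketch of the same idea, assuming nothing beyond the JDK, is to bind an ephemeral socket and read the port back:

import java.net.ServerSocket

// Ask the OS for an ephemeral port, then release it for the service under test.
// There is a small race (another process could grab the port before the test does),
// which is why the real helper retries via Utils.startServiceOnPort.
def findFreePort(): Int = {
  val socket = new ServerSocket(0)
  try socket.getLocalPort finally socket.close()
}
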
+ */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.toList) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 44dec45c227ca..095bfb0c73a9a 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,10 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} +import 
scala.collection.JavaConversions._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -236,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.length == ports.length) + val addresses = hosts.zip(ports).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 0000000000000..91d63d49dbec3 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
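FlumeUtilsPythonHelper above flattens each SparkFlumeEvent into a (headers, body) pair of byte arrays so the Python side can decode it without Avro classes: the header map is written as an Int count followed by length-prefixed UTF-8 strings via PythonRDD.writeUTF. A standalone sketch of that framing using only java.io (writeUTF here is the JDK's modified-UTF variant rather than PythonRDD's 4-byte-length encoding, so treat it as an approximation of the wire format):

import java.io.{ByteArrayOutputStream, DataOutputStream}

// Encode a string-to-string map as: Int count, then alternating key/value strings.
def mapToBytes(headers: Map[String, String]): Array[Byte] = {
  val bytes = new ByteArrayOutputStream()
  val out = new DataOutputStream(bytes)
  try {
    out.writeInt(headers.size)
    headers.foreach { case (k, v) =>
      out.writeUTF(k)  // the real helper uses PythonRDD.writeUTF instead
      out.writeUTF(v)
    }
    bytes.toByteArray
  } finally {
    out.close()
  }
}
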
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): JList[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. 
+ val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + Map[String, 
String](s"test-$t" -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 43c1b865b64a1..d5f9a0aa38f9f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,50 +18,39 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder +import com.google.common.base.Charsets.UTF_8 +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} - -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} -class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils - ignore("flume polling test") { + test("flume polling test") { testMultipleTimes(testFlumePolling) } - ignore("flume polling test multiple hosts") { + test("flume polling test multiple hosts") { testMultipleTimes(testFlumePollingMultipleHost) } @@ -86,69 +75,33 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging } private def testFlumePolling(): Unit = { - // Start the channel and sink. 
- val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - ssc.start() + try { + val port = utils.startSingleSink() - writeAndVerify(Seq(channel), ssc, outputBuffer) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() + try { + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } + } + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) @@ -156,87 +109,21 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging ssc.start() try { - writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) - } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() - } - } - - def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) { - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val 
executorCompletion = new ExecutorCompletionService[Void](executor) - channels.map(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - for (i <- 0 until channels.size) { - executorCompletion.take() - } - val startTime = System.currentTimeMillis() - while (outputBuffer.size < batchCount * channels.size && - System.currentTimeMillis() - startTime < 15000) { - logInfo("output.size = " + outputBuffer.size) - Thread.sleep(100) - } - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.map { + case kv => (kv._1.toString, kv._2.toString) + }).map(mapAsJavaMap) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - } - assert(counter === totalEventsPerChannel * channels.size) - } - - def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach - clock.advance(batchDuration.milliseconds) - } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 39e6754c81dbf..5bc4cdf65306c 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,46 +17,26 @@ package org.apache.spark.streaming.flume -import 
java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.util.Utils -class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,19 +49,29 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + output should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -98,58 +88,6 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new 
AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) - output should be (input) - } - } - /** Class to create socket channel with compression */ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 0b79f47647f6b..977514fa5a1ec 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -58,6 +58,7 @@ maven-shade-plugin false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kafka-assembly-${project.version}.jar *:* diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 243ce6eaca658..ded863bd985e8 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.kafka kafka_${scala.binary.version} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 6715aede7928a..48a1933d92f85 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka import scala.annotation.tailrec import scala.collection.mutable -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -65,6 +65,9 @@ class DirectKafkaInputDStream[ val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) + // Keep this 
consistent with how other streams are named (e.g. "Flume polling stream [2]") + private[streaming] override def name: String = s"Kafka direct stream [$id]" + protected[streaming] override val checkpointData = new DirectKafkaInputDStreamCheckpointData @@ -116,9 +119,23 @@ class DirectKafkaInputDStream[ val rdd = KafkaRDD[K, V, U, T, R]( context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) - // Report the record number of this batch interval to InputInfoTracker. - val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum - val inputInfo = InputInfo(id, numRecords) + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. + offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) @@ -150,10 +167,7 @@ class DirectKafkaInputDStream[ override def restore() { // this is assuming that the topics don't change during execution, which is true currently val topics = fromOffsets.keySet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 6cf254a7b69cb..8465432c5850f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { r.flatMap { tm: TopicMetadata => tm.partitionsMetadata.map { pm: PartitionMetadata => TopicAndPartition(tm.topic, pm.partitionId) - } + } } } } @@ -360,6 +360,14 @@ private[spark] object KafkaCluster { type Err = ArrayBuffer[Throwable] + /** If the result is right, return it, otherwise throw SparkException */ + def checkErrors[T](result: Either[Err, T]): T = { + result.fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) @@ -402,7 +410,7 @@ object KafkaCluster { } Seq("zookeeper.connect", "group.id").foreach { s => - if (!props.contains(s)) { + if (!props.containsKey(s)) { props.setProperty(s, "") } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index cca0fac0234e1..04b2dc10d39ea 100644 --- 
a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -135,7 +135,7 @@ class KafkaReceiver[ store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { - case e: Throwable => logError("Error handling message; exiting", e) + case e: Throwable => reportError("Error handling message; exiting", e) } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a1b4a12e5d6a0..c5cd2154772ac 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.kafka +import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{PartialResult, BoundedDouble} import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator @@ -60,6 +62,48 @@ class KafkaRDD[ }.toArray } + override def count(): Long = offsetRanges.map(_.count).sum + + override def countApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[R] = { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (num < 1 || nonEmptyPartitions.size < 1) { + return new Array[R](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[R] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, + parts.keys.toArray, + allowLocal = true) + res.foreach(buf ++= _) + buf.toArray + } + override def getPreferredLocations(thePart: Partition): Seq[String] = { val part = thePart.asInstanceOf[KafkaRDDPartition] // TODO is additional hostname resolution necessary here diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a842a6f17766f..a660d2a00c35d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -35,4 +35,7 @@ class KafkaRDDPartition( val untilOffset: Long, val host: String, val port: Int -) extends Partition +) extends Partition { + /** Number of messages this partition refers to */ + def count(): Long = untilOffset - fromOffset +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 6dc4e9517d5a4..b608b75952721 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ 
b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging { val props = new Properties() props.put("metadata.broker.list", brokerAddress) props.put("serializer.class", classOf[StringEncoder].getName) + // wait for all in-sync replicas to ack sends + props.put("request.required.acks", "-1") props } @@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging { tryAgain(1) } - /** Wait until the leader offset for the given topic/partition equals the specified offset */ - def waitUntilLeaderOffset( - topic: String, - partition: Int, - offset: Long): Unit = { - eventually(Time(10000), Time(100)) { - val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress)) - val tp = TopicAndPartition(topic, partition) - val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset - assert( - llo == offset, - s"$topic $partition $offset not reached after timeout") - } - } - private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index d7cf500577c2a..f3b01bd60b178 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -158,15 +158,31 @@ object KafkaUtils { /** get leaders for the given offset ranges, or throw an exception */ private def leadersForRanges( - kafkaParams: Map[String, String], + kc: KafkaCluster, offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { - val kc = new KafkaCluster(kafkaParams) val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) - leaders + val leaders = kc.findLeaders(topics) + KafkaCluster.checkErrors(leaders) + } + + /** Make sure offsets are available in kafka, or throw an exception */ + private def checkOffsets( + kc: KafkaCluster, + offsetRanges: Array[OffsetRange]): Unit = { + val topics = offsetRanges.map(_.topicAndPartition).toSet + val result = for { + low <- kc.getEarliestLeaderOffsets(topics).right + high <- kc.getLatestLeaderOffsets(topics).right + } yield { + offsetRanges.filterNot { o => + low(o.topicAndPartition).offset <= o.fromOffset && + o.untilOffset <= high(o.topicAndPartition).offset + } + } + val badRanges = KafkaCluster.checkErrors(result) + if (!badRanges.isEmpty) { + throw new SparkException("Offsets not available on leader: " + badRanges.mkString(",")) + } } /** @@ -189,9 +205,11 @@ object KafkaUtils { sc: SparkContext, kafkaParams: Map[String, String], offsetRanges: Array[OffsetRange] - ): RDD[(K, V)] = { + ): RDD[(K, V)] = sc.withScope { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) - val leaders = leadersForRanges(kafkaParams, offsetRanges) + val kc = new KafkaCluster(kafkaParams) + val leaders = leadersForRanges(kc, offsetRanges) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) } @@ -224,16 +242,19 @@ object KafkaUtils { offsetRanges: Array[OffsetRange], leaders: Map[TopicAndPartition, Broker], messageHandler: 
MessageAndMetadata[K, V] => R - ): RDD[R] = { + ): RDD[R] = sc.withScope { + val kc = new KafkaCluster(kafkaParams) val leaderMap = if (leaders.isEmpty) { - leadersForRanges(kafkaParams, offsetRanges) + leadersForRanges(kc, offsetRanges) } else { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) }.toMap } - new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) + val cleanedHandler = sc.clean(messageHandler) + checkOffsets(kc, offsetRanges) + new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler) } /** @@ -256,7 +277,7 @@ object KafkaUtils { valueDecoderClass: Class[VD], kafkaParams: JMap[String, String], offsetRanges: Array[OffsetRange] - ): JavaPairRDD[K, V] = { + ): JavaPairRDD[K, V] = jsc.sc.withScope { implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) @@ -294,7 +315,7 @@ object KafkaUtils { offsetRanges: Array[OffsetRange], leaders: JMap[TopicAndPartition, Broker], messageHandler: JFunction[MessageAndMetadata[K, V], R] - ): JavaRDD[R] = { + ): JavaRDD[R] = jsc.sc.withScope { implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) @@ -314,7 +335,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -348,8 +369,9 @@ object KafkaUtils { fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R ): InputDStream[R] = { + val cleanedHandler = ssc.sc.clean(messageHandler) new DirectKafkaInputDStream[K, V, KD, VD, R]( - ssc, kafkaParams, fromOffsets, messageHandler) + ssc, kafkaParams, fromOffsets, cleanedHandler) } /** @@ -361,7 +383,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). 
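
A minimal usage sketch of the batch createRDD path touched above; the broker address, topic name, and offsets are illustrative placeholders, not values taken from this change:

```scala
import kafka.serializer.StringDecoder

import org.apache.spark.SparkContext
import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

object KafkaRDDSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "kafka-rdd-sketch")
    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")

    // With the added checkOffsets call, ranges that are not fully available on
    // the leader now fail fast with a SparkException when the RDD is created.
    val offsetRanges = Array(OffsetRange("events", 0, 0L, 100L))

    val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
      sc, kafkaParams, offsetRanges)

    // count, isEmpty and take are now answered from the offset ranges themselves,
    // so this does not scan the topic.
    println(s"messages in range: ${rdd.count()}")
    sc.stop()
  }
}
```
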
@@ -397,7 +419,7 @@ object KafkaUtils { val kc = new KafkaCluster(kafkaParams) val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - (for { + val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { kc.getEarliestLeaderOffsets(topicPartitions) @@ -410,10 +432,8 @@ object KafkaUtils { } new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( ssc, kafkaParams, fromOffsets, messageHandler) - }).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + } + KafkaCluster.checkErrors(result) } /** @@ -425,7 +445,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -469,11 +489,12 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) createDirectStream[K, V, KD, VD, R]( jssc.ssc, Map(kafkaParams.toSeq: _*), Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), - messageHandler.call _ + cleanedHandler ) } @@ -486,7 +507,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). 
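
The direct-stream entry points above now pass the user-supplied message handler through the closure cleaner and funnel leader and offset errors through KafkaCluster.checkErrors. A hedged end-to-end sketch of that API follows; the broker, topic, partition, and starting offset are assumptions, not part of this change:

```scala
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils, OffsetRange}

object DirectStreamSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("direct-stream-sketch")
    val ssc = new StreamingContext(conf, Seconds(2))

    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
    // Start each partition from an explicit offset instead of auto.offset.reset.
    val fromOffsets = Map(TopicAndPartition("events", 0) -> 0L)

    // The handler is shipped to executors; after this change it is passed through
    // the closure cleaner, so captures of non-serializable state fail early.
    val handler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)

    val stream = KafkaUtils.createDirectStream[
      String, String, StringDecoder, StringDecoder, (String, String)](
      ssc, kafkaParams, fromOffsets, handler)

    // Capture the offset ranges of each batch on the driver, as the updated
    // test suites do, and log them alongside the batch output.
    var offsetRanges = Array.empty[OffsetRange]
    stream.transform { rdd =>
      offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      rdd
    }.foreachRDD { rdd =>
      offsetRanges.foreach { o =>
        println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}")
      }
      println(s"batch size: ${rdd.count()}")
    }

    ssc.start()
    ssc.awaitTermination()
  }
}
```

Capturing offsetRanges inside transform rather than foreachRDD keeps the driver-side variable in step with the batch being processed, which is the pattern the revised suites exercise.
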
@@ -649,4 +670,17 @@ private class KafkaUtilsPythonHelper { TopicAndPartition(topic, partition) def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 9c3dfeb8f5928..f326e7f1f6f8d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -55,6 +55,12 @@ final class OffsetRange private( val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple + /** Kafka TopicAndPartition object, for convenience */ + def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition) + + /** Number of messages this OffsetRange refers to */ + def count(): Long = untilOffset - fromOffset + override def equals(obj: Any): Boolean = obj match { case that: OffsetRange => this.topic == that.topic && @@ -69,7 +75,7 @@ final class OffsetRange private( } override def toString(): String = { - s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" } /** this is to avoid ClassNotFoundException during checkpoint restore */ diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index ea87e960379f1..75f0dfc22b9dc 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -267,7 +267,7 @@ class ReliableKafkaReceiver[ } } catch { case e: Exception => - logError("Error handling message", e) + reportError("Error handling message", e) } } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 4c1d6a03eb2b8..02cd24a35906f 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -18,9 +18,8 @@ package org.apache.spark.streaming.kafka; import java.io.Serializable; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Arrays; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; import scala.Tuple2; @@ -34,6 +33,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; @@ -67,8 +67,10 @@ public void tearDown() { 
@Test public void testKafkaStream() throws InterruptedException { - String topic1 = "topic1"; - String topic2 = "topic2"; + final String topic1 = "topic1"; + final String topic2 = "topic2"; + // hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); @@ -89,6 +91,17 @@ public void testKafkaStream() throws InterruptedException { StringDecoder.class, kafkaParams, topicToSet(topic1) + ).transformToPair( + // Make sure you can get offset ranges from the rdd + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + Assert.assertEquals(offsets[0].topic(), topic1); + return rdd; + } + } ).map( new Function, String>() { @Override @@ -116,12 +129,17 @@ public String call(MessageAndMetadata msgAndMd) throws Exception ); JavaDStream unifiedStream = stream1.union(stream2); - final HashSet result = new HashSet(); + final Set result = Collections.synchronizedSet(new HashSet()); unifiedStream.foreachRDD( new Function, Void>() { @Override public Void call(JavaRDD rdd) throws Exception { result.addAll(rdd.collect()); + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } return null; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index 5cf379635354f..a9dc6e50613ca 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -72,9 +72,6 @@ public void testKafkaRDD() throws InterruptedException { HashMap kafkaParams = new HashMap(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); - kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length); - kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length); - OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), OffsetRange.create(topic2, 0, 0, 1) diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 540f4ceabab47..e4c659215b767 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -18,9 +18,7 @@ package org.apache.spark.streaming.kafka; import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Random; +import java.util.*; import scala.Tuple2; @@ -94,7 +92,7 @@ public void testKafkaStream() throws InterruptedException { topics, StorageLevel.MEMORY_ONLY_SER()); - final HashMap result = new HashMap(); + final Map result = Collections.synchronizedMap(new HashMap()); JavaDStream words = stream.map( new Function, String>() { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index b6d314dfc7783..5b3c79444aa68 100644 --- 
a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -28,10 +28,10 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.Utils class DirectKafkaStreamSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll with Eventually @@ -99,15 +99,24 @@ class DirectKafkaStreamSuite ssc, kafkaParams, topics) } - val allReceived = new ArrayBuffer[(String, String)] + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] - stream.foreachRDD { rdd => - // Get the offset ranges in the RDD - val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + // hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + stream.transform { rdd => + // Get the offset ranges in the RDD + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.foreachRDD { rdd => + for (o <- offsetRanges) { + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, // and the number of items in the partition - val off = offsets(i) + val off = offsetRanges(i) val all = iter.toSeq val partSize = all.size val rangeSize = off.untilOffset - off.fromOffset @@ -162,7 +171,7 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) @@ -208,7 +217,7 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] stream.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) @@ -324,7 +333,8 @@ class DirectKafkaStreamSuite ssc, kafkaParams, Set(topic)) } - val allReceived = new ArrayBuffer[(String, String)] + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] stream.foreachRDD { rdd => allReceived ++= rdd.collect() } ssc.start() @@ -350,8 +360,8 @@ class DirectKafkaStreamSuite } object DirectKafkaStreamSuite { - val collectedData = new mutable.ArrayBuffer[String]() - var total = -1L + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + @volatile var total = -1L class InputInfoCollector extends StreamingListener { val numRecordsSubmitted = new AtomicLong(0L) diff --git 
a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index 7fb841b79cb65..d66830cbacdee 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { +import org.apache.spark.SparkFunSuite + +class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) private var kc: KafkaCluster = null diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 39c3fb448ff57..f52a738afd65b 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,11 +22,11 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.spark._ -class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ @@ -55,21 +55,39 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { test("basic usage") { val topic = s"topicbasic-${Random.nextInt}" kafkaTestUtils.createTopic(topic) - val messages = Set("the", "quick", "brown", "fox") - kafkaTestUtils.sendMessages(topic, messages.toArray) + val messages = Array("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt}") - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size) - val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) - val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet - assert(received === messages) + assert(received === messages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === messages.size) + assert(rdd.countApprox(0).getFinalValue.mean === messages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head._2 === messages.head) + assert(rdd.take(messages.size + 10).size === messages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0))) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, badRanges) + } } test("iterator boundary 
conditions") { @@ -86,7 +104,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { // this is the "lots of messages" case kafkaTestUtils.sendMessages(topic, sent) val sentCount = sent.values.sum - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount) // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) @@ -113,7 +130,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val sentOnlyOne = Map("d" -> 1) kafkaTestUtils.sendMessages(topic, sentOnlyOne) - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1) assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 24699dfc33adb..797b07f80d8ee 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -23,14 +23,14 @@ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { +class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ @@ -65,7 +65,7 @@ class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() + val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] stream.map(_._2).countByValue().foreachRDD { r => val ret = r.collect() ret.toMap.foreach { kv => @@ -77,10 +77,7 @@ class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { ssc.start() eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert(sent.size === result.size) - sent.keys.foreach { k => - assert(sent(k) === result(k).toInt) - } + assert(sent === result) } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 38548dd73b82c..80e2df62de3fe 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -26,15 +26,15 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, 
StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends FunSuite +class ReliableKafkaStreamSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with Eventually { private val sparkConf = new SparkConf() diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 98f95a9a64fa0..0e41e5781784b 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.eclipse.paho org.eclipse.paho.client.mqttv3 diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 3c0ef94cb0fab..7c2f18cb35bda 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,25 +17,12 @@ package org.apache.spark.streaming.mqtt -import java.io.IOException -import java.util.concurrent.Executors -import java.util.Properties - -import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.MqttTopic import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ @@ -57,6 +44,8 @@ class MQTTInputDStream( storageLevel: StorageLevel ) extends ReceiverInputDStream[String](ssc_) { + private[streaming] override def name: String = s"MQTT stream [$id]" + def getReceiver(): Receiver[String] = { new MQTTReceiver(brokerUrl, topic, storageLevel) } @@ -86,7 +75,7 @@ class MQTTReceiver( // Handles Mqtt message override def messageArrived(topic: String, message: MqttMessage) { - store(new String(message.getPayload(),"utf-8")) + store(new String(message.getPayload(), "utf-8")) } override def deliveryComplete(token: IMqttDeliveryToken) { diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index a19a72c58a705..c4bf5aa7869bb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} @@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import 
org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 8b6a8959ac4cf..178ae8de13b57 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.twitter4j twitter4j-stream diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 9ee57d7581d85..d9acb568879fe 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.streaming.twitter -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a50d378b34335..37bfd10d43663 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + ${akka.group} akka-zeromq_${scala.binary.version} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index a7566e733d891..35d2e62c68480 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends FunSuite { +class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 
4351a8a12fe21..3636a9037d43f 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 25847a1b33d9c..c242e7a57b9ab 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -40,6 +40,13 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -59,7 +66,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index b0bff27a61c19..06e0ff28afd95 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.regex.Pattern; +import com.amazonaws.regions.RegionUtils; import org.apache.log4j.Logger; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -40,140 +41,146 @@ import com.google.common.collect.Lists; /** - * Java-friendly Kinesis Spark Streaming WordCount example + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details - * on the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: JavaKinesisWordCountASL [app-name] [stream-name] [endpoint-url] [region-name] + * [app-name] is the name of the consumer app, used to track the read data in DynamoDB + * [stream-name] name of the Kinesis stream (ie. mySparkStream) + * [endpoint-url] endpoint of the Kinesis service + * (e.g. https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: JavaKinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. 
https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID=[your-access-key] * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * # run the example + * $ SPARK_HOME/bin/run-example streaming.JavaKinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data - * onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ public final class JavaKinesisWordCountASL { // needs to be public for access from run-example - private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); - private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); - - /* Make the constructor private to enforce singleton */ - private JavaKinesisWordCountASL() { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + public static void main(String[] args) { + // Check that all required args were passed in. + if (args.length != 3) { + System.err.println( + "Usage: JavaKinesisWordCountASL \n\n" + + " is the name of the app, used to track the read data in DynamoDB\n" + + " is the name of the Kinesis stream\n" + + " is the endpoint of the Kinesis service\n" + + " (e.g. https://kinesis.us-east-1.amazonaws.com)\n" + + "Generate data for the Kinesis stream using the example KinesisWordProducerASL.\n" + + "See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more\n" + + "details.\n" + ); + System.exit(1); } - public static void main(String[] args) { - /* Check that all required args were passed in. */ - if (args.length < 2) { - System.err.println( - "Usage: JavaKinesisWordCountASL \n" + - " is the name of the Kinesis stream\n" + - " is the endpoint of the Kinesis service\n" + - " (e.g. 
https://kinesis.us-east-1.amazonaws.com)\n"); - System.exit(1); - } - - StreamingExamples.setStreamingLogLevels(); - - /* Populate the appropriate variables from the given args */ - String streamName = args[0]; - String endpointUrl = args[1]; - /* Set the batch interval to a fixed 2000 millis (2 seconds) */ - Duration batchInterval = new Duration(2000); - - /* Create a Kinesis client in order to determine the number of shards for the given stream */ - AmazonKinesisClient kinesisClient = new AmazonKinesisClient( - new DefaultAWSCredentialsProviderChain()); - kinesisClient.setEndpoint(endpointUrl); - - /* Determine the number of shards from the stream */ - int numShards = kinesisClient.describeStream(streamName) - .getStreamDescription().getShards().size(); - - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ - int numStreams = numShards; - - /* Setup the Spark config. */ - SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount"); - - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ - Duration checkpointInterval = batchInterval; + // Set default log4j logging level to WARN to hide Spark logs + StreamingExamples.setStreamingLogLevels(); + + // Populate the appropriate variables from the given args + String kinesisAppName = args[0]; + String streamName = args[1]; + String endpointUrl = args[2]; + + // Create a Kinesis client in order to determine the number of shards for the given stream + AmazonKinesisClient kinesisClient = + new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + int numShards = + kinesisClient.describeStream(streamName).getStreamDescription().getShards().size(); + + + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. + int numStreams = numShards; + + // Spark Streaming batch interval + Duration batchInterval = new Duration(2000); + + // Kinesis checkpoint interval. Same as batchInterval for this example. 
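
The same setup step can be sketched compactly in Scala; only the shard and region lookup that both word-count examples perform is shown, and the stream name and endpoint below are placeholders rather than values from this change:

```scala
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.AmazonKinesisClient

object KinesisSetupSketch {
  def main(args: Array[String]): Unit = {
    val endpointUrl = "https://kinesis.us-east-1.amazonaws.com"
    val streamName = "mySparkStream"

    val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain())
    kinesisClient.setEndpoint(endpointUrl)

    // One receiver per shard is the example's choice; fewer receivers also work,
    // since the shards are then distributed across them.
    val numShards =
      kinesisClient.describeStream(streamName).getStreamDescription().getShards().size

    // Kinesis Client Library metadata (DynamoDB) lives in the stream's own region.
    val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()

    println(s"stream $streamName has $numShards shard(s) in region $regionName")
  }
}
```
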
+ Duration kinesisCheckpointInterval = batchInterval; + + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName(); + + // Setup the Spark config and StreamingContext + SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + // Create the Kinesis DStreams + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2()) + ); + } - /* Setup the StreamingContext */ - JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + // Union all the streams if there is more than 1 stream + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + // Otherwise, just use the 1 stream + unionStreams = streamsList.get(0); + } - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ - List> streamsList = new ArrayList>(numStreams); - for (int i = 0; i < numStreams; i++) { - streamsList.add( - KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) - ); + // Convert each line of Array[Byte] to String, and split into words + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + // Map each word to a (word, 1) tuple so we can reduce by key to count the words + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } } - - /* Union all the streams if there is more than 1 stream */ - JavaDStream unionStreams; - if (streamsList.size() > 1) { - unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); - } else { - /* Otherwise, just use the 1 stream */ - unionStreams = streamsList.get(0); + ).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } } + ); - /* - * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. - * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. - */ - JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { - @Override - public Iterable call(byte[] line) { - return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); - } - }); - - /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. 
*/ - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); - - /* Print the first 10 wordCounts */ - wordCounts.print(); - - /* Start the streaming context and await termination */ - jssc.start(); - jssc.awaitTermination(); - } + // Print the first 10 wordCounts + wordCounts.print(); + + // Start the streaming context and await termination + jssc.start(); + jssc.awaitTermination(); + } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 32da0858d1a1d..de749626ec09c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -15,226 +15,253 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.nio.ByteBuffer + import scala.util.Random -import org.apache.spark.Logging -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions -import org.apache.spark.streaming.kinesis.KinesisUtils -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain + +import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials} +import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.PutRecordRequest -import org.apache.log4j.Logger -import org.apache.log4j.Level +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.streaming.dstream.DStream.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils + /** - * Kinesis Spark Streaming WordCount example. + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on - * the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: KinesisWordCountASL + * is the name of the consumer app, used to track the read data in DynamoDB + * name of the Kinesis stream (ie. mySparkStream) + * endpoint of the Kinesis service + * (e.g. 
https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: KinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * + * # run the example + * $ SPARK_HOME/bin/run-example streaming.KinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com * - * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class below called KinesisWordCountProducerASL which puts - * dummy data onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ -private object KinesisWordCountASL extends Logging { +object KinesisWordCountASL extends Logging { def main(args: Array[String]) { - /* Check that all required args were passed in. */ - if (args.length < 2) { + // Check that all required args were passed in. + if (args.length != 3) { System.err.println( """ - |Usage: KinesisWordCount + |Usage: KinesisWordCountASL + | + | is the name of the consumer app, used to track the read data in DynamoDB | is the name of the Kinesis stream | is the endpoint of the Kinesis service | (e.g. https://kinesis.us-east-1.amazonaws.com) + | + |Generate input data for Kinesis stream using the example KinesisWordProducerASL. + |See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more + |details. 
""".stripMargin) System.exit(1) } StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ - val Array(streamName, endpointUrl) = args + // Populate the appropriate variables from the given args + val Array(appName, streamName, endpointUrl) = args - /* Determine the number of shards from the stream */ - val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + + // Determine the number of shards from the stream using the low-level Kinesis Client + // from the AWS Java SDK. + val credentials = new DefaultAWSCredentialsProviderChain().getCredentials() + require(credentials != null, + "No AWS credentials found. Please specify credentials using one of the methods specified " + + "in http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html") + val kinesisClient = new AmazonKinesisClient(credentials) kinesisClient.setEndpoint(endpointUrl) - val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() - .size() + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards().size + - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. val numStreams = numShards - /* Setup the and SparkConfig and StreamingContext */ - /* Spark Streaming batch interval */ + // Spark Streaming batch interval val batchInterval = Milliseconds(2000) - val sparkConfig = new SparkConf().setAppName("KinesisWordCount") - val ssc = new StreamingContext(sparkConfig, batchInterval) - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + // Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information + // on sequence number of records that have been received. Same as batchInterval for this + // example. 
val kinesisCheckpointInterval = batchInterval - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + + // Setup the SparkConfig and StreamingContext + val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + // Create the Kinesis DStreams val kinesisStreams = (0 until numStreams).map { i => - KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + KinesisUtils.createStream(ssc, appName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2) } - /* Union all the streams */ + // Union all the streams val unionStreams = ssc.union(kinesisStreams) - /* Convert each line of Array[Byte] to String, split into words, and count them */ - val words = unionStreams.flatMap(byteArray => new String(byteArray) - .split(" ")) + // Convert each line of Array[Byte] to String, and split into words + val words = unionStreams.flatMap(byteArray => new String(byteArray).split(" ")) - /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + // Map each word to a (word, 1) tuple so we can reduce by key to count the words val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) - /* Print the first 10 wordCounts */ + // Print the first 10 wordCounts wordCounts.print() - /* Start the streaming context and await termination */ + // Start the streaming context and await termination ssc.start() ssc.awaitTermination() } } /** - * Usage: KinesisWordCountProducerASL - * + * Usage: KinesisWordProducerASL \ + * + * * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service + * is the endpoint of the Kinesis service * (ie. https://kinesis.us-east-1.amazonaws.com) * is the rate of records per second to put onto the stream * is the rate of records per second to put onto the stream * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com 10 5 + * $ SPARK_HOME/bin/run-example streaming.KinesisWordProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com us-east-1 10 5 */ -private object KinesisWordCountProducerASL { +object KinesisWordProducerASL { def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: KinesisWordCountProducerASL " + - " ") + if (args.length != 4) { + System.err.println( + """ + |Usage: KinesisWordProducerASL + + | + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. 
https://kinesis.us-east-1.amazonaws.com) + | is the rate of records per second to put onto the stream + | is the rate of records per second to put onto the stream + | + """.stripMargin) + System.exit(1) } + // Set default log4j logging level to WARN to hide Spark logs StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ + // Populate the appropriate variables from the given args val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args - /* Generate the records and return the totals */ - val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + // Generate the records and return the totals + val totals = generate(stream, endpoint, recordsPerSecond.toInt, + wordsPerRecord.toInt) - /* Print the array of (index, total) tuples */ - println("Totals") - totals.foreach(total => println(total.toString())) + // Print the array of (word, total) tuples + println("Totals for the words sent") + totals.foreach(println(_)) } def generate(stream: String, endpoint: String, recordsPerSecond: Int, - wordsPerRecord: Int): Seq[(Int, Int)] = { + wordsPerRecord: Int): Seq[(String, Int)] = { - val MaxRandomInts = 10 + val randomWords = List("spark", "you", "are", "my", "father") + val totals = scala.collection.mutable.Map[String, Int]() - /* Create the Kinesis client */ + // Create the low-level Kinesis Client from the AWS Java SDK. val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) kinesisClient.setEndpoint(endpoint) println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + - s" $recordsPerSecond records per second and $wordsPerRecord words per record"); - - val totals = new Array[Int](MaxRandomInts) - /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ - for (i <- 1 to 5) { - - /* Generate recordsPerSec records to put onto the stream */ - val records = (1 to recordsPerSecond.toInt).map { recordNum => - /* - * Randomly generate each wordsPerRec words between 0 (inclusive) - * and MAX_RANDOM_INTS (exclusive) - */ + s" $recordsPerSecond records per second and $wordsPerRecord words per record") + + // Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord + for (i <- 1 to 10) { + // Generate recordsPerSec records to put onto the stream + val records = (1 to recordsPerSecond.toInt).foreach { recordNum => + // Randomly generate wordsPerRecord number of words val data = (1 to wordsPerRecord.toInt).map(x => { - /* Generate the random int */ - val randomInt = Random.nextInt(MaxRandomInts) + // Get a random index to a word + val randomWordIdx = Random.nextInt(randomWords.size) + val randomWord = randomWords(randomWordIdx) - /* Keep track of the totals */ - totals(randomInt) += 1 + // Increment total count to compare to server counts later + totals(randomWord) = totals.getOrElse(randomWord, 0) + 1 - randomInt.toString() + randomWord }).mkString(" ") - /* Create a partitionKey based on recordNum */ + // Create a partitionKey based on recordNum val partitionKey = s"partitionKey-$recordNum" - /* Create a PutRecordRequest with an Array[Byte] version of the data */ + // Create a PutRecordRequest with an Array[Byte] version of the data val putRecordRequest = new PutRecordRequest().withStreamName(stream) .withPartitionKey(partitionKey) - .withData(ByteBuffer.wrap(data.getBytes())); + .withData(ByteBuffer.wrap(data.getBytes())) - /* Put the record onto the stream and capture the PutRecordResult */ - val 
putRecordResult = kinesisClient.putRecord(putRecordRequest); + // Put the record onto the stream and capture the PutRecordResult + val putRecordResult = kinesisClient.putRecord(putRecordRequest) } - /* Sleep for a second */ + // Sleep for a second Thread.sleep(1000) println("Sent " + recordsPerSecond + " records") } - - /* Convert the totals to (index, total) tuple */ - (0 to (MaxRandomInts - 1)).zip(totals) + // Convert the totals to (index, total) tuple + totals.toSeq.sortBy(_._1) } } -/** - * Utility functions for Spark Streaming examples. +/** + * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ private[streaming] object StreamingExamples extends Logging { - - /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + // Set reasonable logging levels for streaming if the user has not configured log4j. def setStreamingLogLevels() { val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements if (!log4jInitialized) { @@ -246,3 +273,4 @@ private[streaming] object StreamingExamples extends Logging { } } } +// scalastyle:on println diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 588e86a1887ec..83a4537559512 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. * - * @param checkpointInterval + * @param checkpointInterval * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) */ private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, + checkpointInterval: Duration, currentClock: Clock = new SystemClock()) extends Logging { - + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** - * Check if it's time to checkpoint based on the current time and the derived time + * Check if it's time to checkpoint based on the current time and the derived time * for the next checkpoint * * @return true if it's time to checkpoint @@ -48,7 +48,7 @@ private[kinesis] class KinesisCheckpointState( /** * Advance the checkpoint clock by the checkpoint interval. 
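The checkpoint-clock pattern this helper implements can be sketched independently of Spark's Clock classes (names below are illustrative, not the Spark-internal API): remember the next checkpoint time, report when it has been reached, and push it forward by one interval once a checkpoint is taken.

    // Minimal sketch of the checkpoint-interval bookkeeping, assuming wall-clock time.
    class CheckpointTimer(intervalMs: Long, now: () => Long = () => System.currentTimeMillis()) {
      private var nextCheckpointTimeMs: Long = now() + intervalMs

      // True once the current time has reached the next scheduled checkpoint time.
      def shouldCheckpoint(): Boolean = now() >= nextCheckpointTimeMs

      // Advance the schedule by one interval after a checkpoint has been taken.
      def advanceCheckpoint(): Unit = {
        nextCheckpointTimeMs += intervalMs
      }
    }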
*/ - def advanceCheckpoint() = { + def advanceCheckpoint(): Unit = { checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index a7fe4476cacb8..1a8a4cecc1141 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -16,39 +16,45 @@ */ package org.apache.spark.streaming.kinesis -import java.net.InetAddress import java.util.UUID +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} + import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import com.amazonaws.auth.AWSCredentialsProvider -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +private[kinesis] +case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) + extends AWSCredentials { + override def getAWSAccessKeyId: String = accessKeyId + override def getAWSSecretKey: String = secretKey +} /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: * https://github.com/awslabs/amazon-kinesis-client - * This is a custom receiver used with StreamingContext.receiverStream(Receiver) - * as described here: - * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers - * to run within a Spark Executor. + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers to run within a + * Spark Executor. * * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing + * the KCL will throw errors. This usually requires deleting the backing * DynamoDB table with the same name this Kinesis application. 
* @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Region name used by the Kinesis Client Library for + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -59,92 +65,121 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). * @param storageLevel Storage level to use for storing the received objects - * - * @return ReceiverInputDStream[Array[Byte]] + * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies + * the credentials */ private[kinesis] class KinesisReceiver( appName: String, streamName: String, endpointUrl: String, - checkpointInterval: Duration, + regionName: String, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel) - extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => - - /* - * The following vars are built in the onStart() method which executes in the Spark Worker after - * this code is serialized and shipped remotely. - */ + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => /* - * workerId should be based on the ip address of the actual Spark Worker where this code runs - * (not the Driver's ip address.) + * ================================================================================= + * The following vars are initialized in the onStart() method which executes in the + * Spark worker after this Receiver is serialized and shipped to the worker. + * ================================================================================= */ - var workerId: String = null - /* - * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file at the default location (~/.aws/credentials) shared by all - * AWS SDKs and the AWS CLI - * Instance profile credentials delivered through the Amazon EC2 metadata service + /** + * workerId is used by the KCL and should be based on the ip address of the actual Spark Worker + * where this code runs (not the driver's IP address.) */ - var credentialsProvider: AWSCredentialsProvider = null - - /* KCL config instance. */ - var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + private var workerId: String = null - /* - * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the - * IRecordProcessor.processRecords() method. - * We're using our custom KinesisRecordProcessor in this case. + /** + * Worker is the core client abstraction from the Kinesis Client Library (KCL). + * A worker can process more than one shard from the given stream. + * Each shard is assigned its own IRecordProcessor and the worker runs multiple such + * processors. */ - var recordProcessorFactory: IRecordProcessorFactory = null + private var worker: Worker = null - /* - * Create a Kinesis Worker.
- * This is the core client abstraction from the Kinesis Client Library (KCL). - * We pass the RecordProcessorFactory from above as well as the KCL config instance. - * A Kinesis Worker can process 1..* shards from the given stream - each with its - * own RecordProcessor. - */ - var worker: Worker = null + /** Thread running the worker */ + private var workerThread: Thread = null /** - * This is called when the KinesisReceiver starts and must be non-blocking. - * The KCL creates and manages the receiving/processing thread pool through the Worker.run() - * method. + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { workerId = Utils.localHostName() + ":" + UUID.randomUUID() - credentialsProvider = new DefaultAWSCredentialsProviderChain() - kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, - credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) - recordProcessorFactory = new IRecordProcessorFactory { + + // KCL config instance + val awsCredProvider = resolveAWSCredentialsProvider() + val kinesisClientLibConfiguration = + new KinesisClientLibConfiguration(appName, streamName, awsCredProvider, workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + val recordProcessorFactory = new IRecordProcessorFactory { override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, workerId, new KinesisCheckpointState(checkpointInterval)) } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) - worker.run() + workerThread = new Thread() { + override def run(): Unit = { + try { + worker.run() + } catch { + case NonFatal(e) => + restart("Error running the KCL worker in Receiver", e) + } + } + } + workerThread.setName(s"Kinesis Receiver ${streamId}") + workerThread.setDaemon(true) + workerThread.start() logInfo(s"Started receiver with workerId $workerId") } /** - * This is called when the KinesisReceiver stops. - * The KCL worker.shutdown() method stops the receiving/processing threads. - * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ override def onStop() { - worker.shutdown() - logInfo(s"Shut down receiver with workerId $workerId") + if (workerThread != null) { + if (worker != null) { + worker.shutdown() + worker = null + } + workerThread.join() + workerThread = null + logInfo(s"Stopped receiver for workerId $workerId") + } workerId = null - credentialsProvider = null - kinesisClientLibConfiguration = null - recordProcessorFactory = null - worker = null + } + + /** + * If AWS credentials are provided, return an AWSCredentialsProvider returning those credentials. + * Otherwise, return the DefaultAWSCredentialsProviderChain.
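In other words, the technique is to expose a fixed key pair through the SDK's provider interface so it can be handed to the Kinesis Client Library. A minimal standalone sketch of that idea (not the private Spark method itself):

    import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials}

    // Wraps fixed keys in an AWSCredentialsProvider; refresh() is a no-op because
    // the keys never change.
    class StaticCredentialsProvider(accessKeyId: String, secretKey: String)
        extends AWSCredentialsProvider {
      override def getCredentials: AWSCredentials = new BasicAWSCredentials(accessKeyId, secretKey)
      override def refresh(): Unit = { }
    }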
+ */ + private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { + awsCredentialsOption match { + case Some(awsCredentials) => + logInfo("Using provided AWS credentials") + new AWSCredentialsProvider { + override def getCredentials: AWSCredentials = awsCredentials + override def refresh(): Unit = { } + } + case None => + logInfo("Using DefaultAWSCredentialsProviderChain") + new DefaultAWSCredentialsProviderChain() + } } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index af8cd875b4541..fe9e3a0c793e2 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -35,7 +35,10 @@ import com.amazonaws.services.kinesis.model.Record /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes @@ -47,8 +50,8 @@ private[kinesis] class KinesisRecordProcessor( workerId: String, checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { - /* shardId to be populated during initialize() */ - var shardId: String = _ + // shardId to be populated during initialize() + private var shardId: String = _ /** * The Kinesis Client Library calls this method during IRecordProcessor initialization. @@ -56,8 +59,8 @@ private[kinesis] class KinesisRecordProcessor( * @param shardId assigned by the KCL to this particular RecordProcessor. */ override def initialize(shardId: String) { - logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") this.shardId = shardId + logInfo(s"Initialized workerId $workerId with shardId $shardId") } /** @@ -66,29 +69,34 @@ private[kinesis] class KinesisRecordProcessor( * and Spark Streaming's Receiver.store(). * * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored + * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { /* - * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming - * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the - * internally-configured Spark serializer (kryo, etc). - * This is not desirable, so we instead store a raw Array[Byte] and decouple - * ourselves from Spark's internal serialization strategy. - */ + * Notes: + * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). 
+ * 2) This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + * 3) For performance, the BlockGenerator is asynchronously queuing elements within its + * memory before creating blocks. This prevents the small block scenario, but requires + * that you register callbacks to know when a block has been generated and stored + * (WAL is sufficient for storage) before can checkpoint back to the source. + */ batch.foreach(record => receiver.store(record.getData().array())) - + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored + * Checkpoint the sequence number of the last record successfully processed/stored * in the batch. * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). * However, if the worker dies unexpectedly, a checkpoint may not happen. @@ -116,22 +124,22 @@ private[kinesis] class KinesisRecordProcessor( logError(s"Exception: WorkerId $workerId encountered and exception while storing " + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) - /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e } } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } /** * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards + * 1) the stream is resharding by splitting or merging adjacent shards * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason * (ShutdownReason.ZOMBIE) * * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE @@ -145,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor( * Checkpoint to indicate that all records from the shard have been drained and processed. * It's now OK to read from the new shards that resulted from a resharding event. */ - case ShutdownReason.TERMINATE => + case ShutdownReason.TERMINATE => KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) /* @@ -190,7 +198,7 @@ private[kinesis] object KinesisRecordProcessor extends Logging { logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) } - /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + /* Throw: Shutdown has been requested by the Kinesis Client Library. 
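The retryRandom helper touched here follows a generic retry-with-random-backoff pattern; a simplified standalone sketch of that pattern (illustrative name and signature, not the Spark-internal one, and retrying on any exception rather than only the KCL's retryable ones):

    import scala.util.Random

    object RetrySketch {
      // Evaluates `expression`, sleeping a random number of milliseconds (up to
      // maxBackOffMillis) and retrying on failure until the retry budget is exhausted,
      // at which point the exception propagates to the caller.
      def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = {
        try {
          expression
        } catch {
          case e: Exception if numRetriesLeft > 1 =>
            Thread.sleep(Random.nextInt(maxBackOffMillis).toLong)
            retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis)
        }
      }
    }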
*/ case _: ShutdownException => { logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 96f4399accd3a..e5acab50181e1 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -16,29 +16,78 @@ */ package org.apache.spark.streaming.kinesis -import org.apache.spark.annotation.Experimental +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream -import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.apache.spark.streaming.{Duration, StreamingContext} -/** - * Helper class to create Amazon Kinesis Input Stream - * :: Experimental :: - */ -@Experimental object KinesisUtils { /** - * Create an InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: - * @param ssc StreamingContext object + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
+ */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, None)) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -48,28 +97,84 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return ReceiverInputDStream[Array[Byte]] + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. 
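For reference, a hedged usage sketch of the new region-aware Scala createStream overload (the application name, stream name, endpoint, and region values below are placeholders):

    import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kinesis.KinesisUtils

    object KinesisStreamSketch {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext(new SparkConf().setAppName("KinesisStreamSketch"), Seconds(2))
        // One receiver; credentials are resolved by the DefaultAWSCredentialsProviderChain
        // on the workers since no explicit keys are passed.
        val bytes = KinesisUtils.createStream(
          ssc, "myKinesisApp", "myStream", "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
          InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
        bytes.map(record => new String(record)).print()
        ssc.start()
        ssc.awaitTermination()
      }
    }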
+ * + * @param ssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental + @deprecated("use other forms of createStream", "1.4.0") def createStream( ssc: StreamingContext, streamName: String, endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, - checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), + initialPositionInStream, checkpointInterval, storageLevel, None)) } /** - * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -79,19 +184,116 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
+ */ + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return JavaReceiverInputDStream[Array[Byte]] + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental def createStream( - jssc: JavaStreamingContext, - streamName: String, - endpointUrl: String, + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. 
+ * + * @param jssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + @deprecated("use other forms of createStream", "1.4.0") + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { - jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, - endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream( + jssc.ssc, streamName, endpointUrl, checkpointInterval, initialPositionInStream, storageLevel) + } + + private def getRegionByEndpoint(endpointUrl: String): String = { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() + } + + private def validateRegion(regionName: String): String = { + Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { + throw new IllegalArgumentException(s"Region name '$regionName' is not valid") + } } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 255fe65819608..2103dca6b766f 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,26 +20,18 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions.seqAsJavaList -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.Seconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.TestSuiteBase -import org.apache.spark.util.{ManualClock, Clock} - -import org.mockito.Mockito._ -import org.scalatest.BeforeAndAfter -import org.scalatest.Matchers -import org.scalatest.mock.MockitoSugar - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import 
com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.util.{Clock, ManualClock, Utils} /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -65,7 +57,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var checkpointStateMock: KinesisCheckpointState = _ var currentClockMock: Clock = _ - override def beforeFunction() = { + override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver] checkpointerMock = mock[IRecordProcessorCheckpointer] checkpointClockMock = mock[ManualClock] @@ -81,15 +73,28 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointStateMock, currentClockMock) } - test("kinesis utils api") { + test("KinesisUtils API") { val ssc = new StreamingContext(master, framework, batchDuration) // Tests the API, does not actually test data receiving - val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + ssc.stop() } + test("check serializability of SerializableAWSCredentials") { + Utils.deserialize[SerializableAWSCredentials]( + Utils.serialize(new SerializableAWSCredentials("x", "y"))) + } + test("process records including store and checkpoint") { when(receiverMock.isStopped()).thenReturn(false) when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index e14bbae4a9b6e..478d0019a25f0 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index d38a3aa8256b7..853dea9a7795e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 058c8c8aa1b24..ce1054ed92ba1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -26,8 +26,8 @@ class EdgeDirection private (private val name: String) extends 
Serializable { * out becomes in and both and either remain the same. */ def reverse: EdgeDirection = this match { - case EdgeDirection.In => EdgeDirection.Out - case EdgeDirection.Out => EdgeDirection.In + case EdgeDirection.In => EdgeDirection.Out + case EdgeDirection.Out => EdgeDirection.In case EdgeDirection.Either => EdgeDirection.Either case EdgeDirection.Both => EdgeDirection.Both } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index cc70b396a8dd4..4611a3ace219b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -41,14 +41,16 @@ abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } + // scalastyle:on structural.type override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { - p.next._2.iterator.map(_.copy()) + p.next()._2.iterator.map(_.copy()) } else { Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index c8790cac3d8a0..65f82429d2029 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -37,7 +37,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { /** * Set the edge properties of this triplet. 
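Since EdgeTriplet is touched here, a brief usage sketch of triplets from user code (local master and toy data are placeholders, not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx._

    object TripletSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("TripletSketch").setMaster("local[2]"))
        val vertices = sc.parallelize(Seq((1L, "alice"), (2L, "bob"), (3L, "carol")))
        val edges = sc.parallelize(Seq(Edge(1L, 2L, "follows"), Edge(2L, 3L, "follows")))
        val graph = Graph(vertices, edges)
        // Each EdgeTriplet exposes the source attribute, edge attribute, and destination attribute.
        graph.triplets.collect().foreach { t =>
          println(s"${t.srcAttr} ${t.attr} ${t.dstAttr}")
        }
        sc.stop()
      }
    }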
*/ - protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = { + protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = { srcId = other.srcId dstId = other.dstId attr = other.attr diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 36dc7b0f86c89..db73a8abc5733 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -316,7 +316,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * satisfy the predicates */ def subgraph( - epred: EdgeTriplet[VD,ED] => Boolean = (x => true), + epred: EdgeTriplet[VD, ED] => Boolean = (x => true), vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 7edd627b20918..9451ff1e5c0e2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -124,18 +124,18 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { val nbrs = edgeDirection match { case EdgeDirection.Either => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => { ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) }, (a, b) => a ++ b, TripletFields.All) case EdgeDirection.In => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), (a, b) => a ++ b, TripletFields.Src) case EdgeDirection.Out => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), (a, b) => a ++ b, TripletFields.Dst) case EdgeDirection.Both => @@ -253,7 +253,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def filter[VD2: ClassTag, ED2: ClassTag]( preprocess: Graph[VD, ED] => Graph[VD2, ED2], epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true, - vpred: (VertexId, VD2) => Boolean = (v:VertexId, d:VD2) => true): Graph[VD, ED] = { + vpred: (VertexId, VD2) => Boolean = (v: VertexId, d: VD2) => true): Graph[VD, ED] = { graph.mask(preprocess(graph).subgraph(epred, vpred)) } @@ -356,7 +356,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)( vprog: (VertexId, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 7372dfbd9fe98..70a7592da8ae3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -32,7 +32,7 @@ trait PartitionStrategy extends Serializable { object PartitionStrategy { /** * Assigns edges to partitions using a 2D partitioning 
of the sparse edge adjacency matrix, - * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication. + * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication. * * Suppose we have a graph with 12 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: @@ -61,26 +61,36 @@ object PartitionStrategy { * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, * P6)` or the last * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be - * replicated to at most `2 * sqrt(numParts) - 1` machines. + * replicated to at most `2 * sqrt(numParts)` machines. * * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the * vertex locations. * - * One of the limitations of this approach is that the number of machines must either be a - * perfect square. We partially address this limitation by computing the machine assignment to - * the next - * largest perfect square and then mapping back down to the actual number of machines. - * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect - * square is used. + * When the number of partitions requested is not a perfect square we use a slightly different + * method where the last column can have a different number of rows than the others while still + * maintaining the same size per block. */ case object EdgePartition2D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt val mixingPrime: VertexId = 1125899906842597L - val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt - val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt - (col * ceilSqrtNumParts + row) % numParts + if (numParts == ceilSqrtNumParts * ceilSqrtNumParts) { + // Use old method for perfect squared to ensure we get same results + val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt + val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt + (col * ceilSqrtNumParts + row) % numParts + + } else { + // Otherwise use new method + val cols = ceilSqrtNumParts + val rows = (numParts + cols - 1) / cols + val lastColRows = numParts - rows * (cols - 1) + val col = (math.abs(src * mixingPrime) % numParts / rows).toInt + val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt + col * rows + row + + } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 01b013ff716fc..cfcf7244eaed5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -147,10 +147,10 @@ object Pregel extends Logging { logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs - oldMessages.unpersist(blocking=false) - newVerts.unpersist(blocking=false) - prevG.unpersistVertices(blocking=false) - prevG.edges.unpersist(blocking=false) + oldMessages.unpersist(blocking = false) + newVerts.unpersist(blocking = false) + prevG.unpersistVertices(blocking = false) + prevG.edges.unpersist(blocking = false) // count the iteration i += 1 } diff --git 
a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index c561570809253..ab021a252eb8a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -156,8 +156,8 @@ class EdgePartition[ val size = data.size var i = 0 while (i < size) { - edge.srcId = srcIds(i) - edge.dstId = dstIds(i) + edge.srcId = srcIds(i) + edge.dstId = dstIds(i) edge.attr = data(i) newData(i) = f(edge) i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index bc974b2f04e70..8c0a461e99fa4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -116,7 +116,7 @@ object PageRank extends Logging { val personalized = srcId isDefined val src: VertexId = srcId.getOrElse(-1L) - def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 } + def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } var iteration = 0 var prevRankGraph: Graph[Double, Double] = null @@ -133,13 +133,13 @@ object PageRank extends Logging { // edge partitions. prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId ,id: VertexId) => resetProb * delta(src,id) + (src: VertexId , id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -243,7 +243,7 @@ object PageRank extends Logging { // Execute a dynamic version of Pregel. val vp = if (personalized) { - (id: VertexId, attr: (Double, Double),msgSum: Double) => + (id: VertexId, attr: (Double, Double), msgSum: Double) => personalizedVertexProgram(id, attr, msgSum) } else { (id: VertexId, attr: (Double, Double), msgSum: Double) => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 3b0e1628d86b5..9cb24ed080e1c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -210,7 +210,7 @@ object SVDPlusPlus { /** * Forces materialization of a Graph by count()ing its RDDs. 
*/ - private def materialize(g: Graph[_,_]): Unit = { + private def materialize(g: Graph[_, _]): Unit = { g.vertices.count() g.edges.count() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index daf162085e3e4..a5d598053f9ca 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ */ object TriangleCount { - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Remove redundant edges val g = graph.groupEdges((a, b) => a).cache() @@ -49,7 +49,7 @@ object TriangleCount { var i = 0 while (i < nbrs.size) { // prevent self cycle - if(nbrs(i) != vid) { + if (nbrs(i) != vid) { set.add(nbrs(i)) } i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index be6b9047d932d..74a7de18d4161 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -66,7 +66,6 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { @@ -122,7 +121,7 @@ private[graphx] object BytecodeUtils { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { - methodsInvoked.add((Class.forName(owner.replace("/", ".")), name)) + methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 2d6a825b61726..989e226305265 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl /** A collection of graph generating functions. */ -object GraphGenerators { +object GraphGenerators extends Logging { val RMATa = 0.45 val RMATb = 0.15 @@ -142,7 +142,7 @@ object GraphGenerators { var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { - println(edges.size + " edges") + logDebug(edges.size + " edges") } edges += addEdge(numVertices) } @@ -243,14 +243,15 @@ object GraphGenerators { * @return A graph containing vertices with the row and column ids * as their attributes and edge values as 1.0. 
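A brief usage sketch of the grid generator documented above (the local master and grid dimensions are arbitrary):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx.util.GraphGenerators

    object GridGraphSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("GridGraphSketch").setMaster("local[2]"))
        // A rows x cols grid has rows * (cols - 1) edges to the right plus (rows - 1) * cols
        // edges down, so a 3 x 4 grid has 9 + 8 = 17 edges and 12 vertices.
        val grid = GraphGenerators.gridGraph(sc, rows = 3, cols = 4)
        println(s"vertices = ${grid.vertices.count()}, edges = ${grid.edges.count()}")
        sc.stop()
      }
    }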
*/ - def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = { + def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = { // Convert row column address into vertex ids (row major order) def sub2ind(r: Int, c: Int): VertexId = r * cols + c - val vertices: RDD[(VertexId, (Int,Int))] = - sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) ) + val vertices: RDD[(VertexId, (Int, Int))] = sc.parallelize(0 until rows).flatMap { r => + (0 until cols).map( c => (sub2ind(r, c), (r, c)) ) + } val edges: RDD[Edge[Double]] = - vertices.flatMap{ case (vid, (r,c)) => + vertices.flatMap{ case (vid, (r, c)) => (if (r + 1 < rows) { Seq( (sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++ (if (c + 1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty }) }.map{ case (src, dst) => Edge(src, dst, 1.0) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala index eb1dbe52c2fda..f1ecc9e2219d1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel -class EdgeRDDSuite extends FunSuite with LocalSparkContext { +class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext { test("cache, getStorageLevel") { // test to see if getStorageLevel returns correct value after caching diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 5a2c73b414279..094a63472eaab 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class EdgeSuite extends FunSuite { +class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( - Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), - Edge(0x2345L, 0x1234L, 1), - Edge(0x1234L, 0x5678L, 1), - Edge(0x1234L, 0x2345L, 1), + Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), + Edge(0x2345L, 0x1234L, 1), + Edge(0x1234L, 0x5678L, 1), + Edge(0x1234L, 0x2345L, 1), Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1) ) // to ascending order val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int]) - + for (i <- 0 until testEdges.length) { assert(sortedEdges(i) == testEdges(testEdges.length - i - 1)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 9bc8007ce49cd..57a8b95dd12e9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd._ -import org.scalatest.FunSuite -class GraphOpsSuite extends FunSuite with LocalSparkContext { +class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { test("joinVertices") { 
withSpark { sc => @@ -59,7 +58,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test ("filter") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val filteredGraph = graph.filter( @@ -67,11 +66,11 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { val degrees: VertexRDD[Int] = graph.outDegrees graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} }, - vpred = (vid: VertexId, deg:Int) => deg > 0 + vpred = (vid: VertexId, deg: Int) => deg > 0 ).cache() val v = filteredGraph.vertices.collect().toSet - assert(v === Set((0,0))) + assert(v === Set((0, 0))) // the map is necessary because of object-reuse in the edge iterator val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index a570e4ed75fc3..1f5e27d5508b8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class GraphSuite extends FunSuite with LocalSparkContext { +class GraphSuite extends SparkFunSuite with LocalSparkContext { def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = { Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v") @@ -248,7 +246,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { test("mask") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() @@ -260,11 +258,11 @@ class GraphSuite extends FunSuite with LocalSparkContext { val projectedGraph = graph.mask(subgraph) val v = projectedGraph.vertices.collect().toSet - assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5))) + assert(v === Set((0, 0), (1, 1), (2, 2), (4, 4), (5, 5))) // the map is necessary because of object-reuse in the edge iterator val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet - assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5))) + assert(e === Set(Edge(0, 1, 1), Edge(0, 2, 2), Edge(0, 5, 5))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 490b94429ea1f..8afa2d403b53f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd._ -class PregelSuite extends FunSuite with LocalSparkContext { +class PregelSuite extends 
SparkFunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index d0a7198d691d7..f1aa685a79c98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.{HashPartitioner, SparkContext} +import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class VertexRDDSuite extends FunSuite with LocalSparkContext { +class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 515f3a9cd02eb..7435647c6d9ee 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class EdgePartitionSuite extends FunSuite { +class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { val builder = new EdgePartitionBuilder[A, Int] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index fe8304c1cdc32..1203f8959f506 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.graphx.impl -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class VertexPartitionSuite extends FunSuite { +class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 4cc30a96408f8..c965a6eb8df13 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import 
org.apache.spark.rdd._ -class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => @@ -52,13 +50,16 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - if(id < 10) { assert(cc === 0) } - else { assert(cc === 10) } + if (id < 10) { + assert(cc === 0) + } else { + assert(cc === 10) + } } val ccMap = vertices.toMap for (id <- 0 until 20) { @@ -75,7 +76,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() @@ -106,9 +107,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { (4L, ("peter", "student")))) // Create an RDD for edges val relationships: RDD[Edge[String]] = - sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), + sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"), - Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) + Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) // Edges are: // 2 ---> 5 ---> 3 // | \ diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala index 61fd0c4605568..808877f0590f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class LabelPropagationSuite extends FunSuite with LocalSparkContext { +class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext { test("Label Propagation") { withSpark { sc => // Construct a graph with two cliques connected by a single edge diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 3f3c9dfd7b3dd..45f1e3011035e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators @@ -31,14 +30,14 @@ object GridPageRank { def sub2ind(r: Int, c: Int): Int = r * nCols + c // Make the grid graph for (r <- 0 until 
nRows; c <- 0 until nCols) { - val ind = sub2ind(r,c) + val ind = sub2ind(r, c) if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r + 1,c)) += ind + inNbrs(sub2ind(r + 1, c)) += ind } if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c + 1)) += ind + inNbrs(sub2ind(r, c + 1)) += ind } } // compute the pagerank @@ -57,7 +56,7 @@ object GridPageRank { } -class PageRankSuite extends FunSuite with LocalSparkContext { +class PageRankSuite extends SparkFunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } @@ -99,8 +98,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val resetProb = 0.15 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb) + val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) .vertices.cache() // Static PageRank should only take 2 iterations to converge @@ -117,7 +116,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } assert(staticErrors.sum === 0) - val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache() + val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ -162,7 +161,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 7bd6b7f3c4ab2..2991438f5e57e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index f2c38e79c452c..d7eaa70ce6407 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class 
ShortestPathsSuite extends FunSuite with LocalSparkContext { +class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { test("Shortest Path Computations") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index 1f658c371ffcf..d6b03208180db 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Island Strongly Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 293c7f3ba4c21..c47552cf3a3bd 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut -class TriangleCountSuite extends FunSuite with LocalSparkContext { +class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => @@ -58,7 +57,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext { val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(0L -> -1L, -1L -> -2L, -2L -> 0L) - val revTriangles = triangles.map { case (a,b) => (b,a) } + val revTriangles = triangles.map { case (a, b) => (b, a) } val rawEdges = sc.parallelize(triangles ++ revTriangles, 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index f3b3738db0dad..61e44dcab578c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BytecodeUtilsSuite extends FunSuite { +// scalastyle:off println +class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass @@ -102,6 +103,7 @@ class BytecodeUtilsSuite extends FunSuite { private val c = {e: TestClass => println(e.baz)} } +// scalastyle:on println object BytecodeUtilsSuite { class TestClass(val foo: Int, val bar: Long) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 8d9c8ddccbb3c..32e0c841c6997 100644 --- 
a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.LocalSparkContext -class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { +class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext { test("GraphGenerators.generateRandomEdges") { val src = 5 diff --git a/launcher/pom.xml b/launcher/pom.xml index ebfa7685eaa18..2fd768d8119c4 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,14 +22,14 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml org.apache.spark spark-launcher_2.10 jar - Spark Launcher Project + Spark Project Launcher http://spark.apache.org/ launcher @@ -49,7 +49,7 @@ org.mockito - mockito-all + mockito-core test @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index b8f02b961113d..5e793a5c48775 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -121,7 +121,10 @@ List buildJavaCommand(String extraClassPath) throws IOException { * set it. */ void addPermGenSizeOpt(List cmd) { - // Don't set MaxPermSize for Java 8 and later. + // Don't set MaxPermSize for IBM Java, or Oracle Java 8 and later. + if (getJavaVendor() == JavaVendor.IBM) { + return; + } String[] version = System.getProperty("java.version").split("\\."); if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { return; @@ -133,7 +136,7 @@ void addPermGenSizeOpt(List cmd) { } } - cmd.add("-XX:MaxPermSize=128m"); + cmd.add("-XX:MaxPermSize=256m"); } void addOptionString(List cmd, String options) { @@ -293,6 +296,9 @@ Properties loadPropertiesFile() throws IOException { try { fd = new FileInputStream(propsFile); props.load(new InputStreamReader(fd, "UTF-8")); + for (Map.Entry e : props.entrySet()) { + e.setValue(e.getValue().toString().trim()); + } } finally { if (fd != null) { try { diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 261402856ac5e..a16c0d2b5ca0b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -27,11 +27,16 @@ */ class CommandBuilderUtils { - static final String DEFAULT_MEM = "512m"; + static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; + /** The set of known JVM vendors. */ + static enum JavaVendor { + Oracle, IBM, OpenJDK, Unknown + }; + /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { return s == null || s.isEmpty(); @@ -108,6 +113,21 @@ static boolean isWindows() { return os.startsWith("Windows"); } + /** Returns an enum value indicating whose JVM is being used. 
*/ + static JavaVendor getJavaVendor() { + String vendorString = System.getProperty("java.vendor"); + if (vendorString.contains("Oracle")) { + return JavaVendor.Oracle; + } + if (vendorString.contains("IBM")) { + return JavaVendor.IBM; + } + if (vendorString.contains("OpenJDK")) { + return JavaVendor.OpenJDK; + } + return JavaVendor.Unknown; + } + /** * Updates the user environment, appending the given pathList to the existing value of the given * environment variable (or setting it if it hasn't yet been set). diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 929b29a49ed70..62492f9baf3bb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -53,21 +53,33 @@ public static void main(String[] argsArray) throws Exception { List args = new ArrayList(Arrays.asList(argsArray)); String className = args.remove(0); - boolean printLaunchCommand; - boolean printUsage; + boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); AbstractCommandBuilder builder; - try { - if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + try { builder = new SparkSubmitCommandBuilder(args); - } else { - builder = new SparkClassCommandBuilder(className, args); + } catch (IllegalArgumentException e) { + printLaunchCommand = false; + System.err.println("Error: " + e.getMessage()); + System.err.println(); + + MainClassOptionParser parser = new MainClassOptionParser(); + try { + parser.parse(args); + } catch (Exception ignored) { + // Ignore parsing exceptions. + } + + List help = new ArrayList(); + if (parser.className != null) { + help.add(parser.CLASS); + help.add(parser.className); + } + help.add(parser.USAGE_ERROR); + builder = new SparkSubmitCommandBuilder(help); } - printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - printUsage = false; - } catch (IllegalArgumentException e) { - builder = new UsageCommandBuilder(e.getMessage()); - printLaunchCommand = false; - printUsage = true; + } else { + builder = new SparkClassCommandBuilder(className, args); } Map env = new HashMap(); @@ -78,13 +90,7 @@ public static void main(String[] argsArray) throws Exception { } if (isWindows()) { - // When printing the usage message, we can't use "cmd /v" since that prevents the env - // variable from being seen in the caller script. So do not call prepareWindowsCommand(). - if (printUsage) { - System.out.println(join(" ", cmd)); - } else { - System.out.println(prepareWindowsCommand(cmd, env)); - } + System.out.println(prepareWindowsCommand(cmd, env)); } else { // In bash, use NULL as the arg separator since it cannot be used in an argument. 
List bashCmd = prepareBashCommand(cmd, env); @@ -135,33 +141,30 @@ private static List prepareBashCommand(List cmd, Map buildCommand(Map env) { - if (isWindows()) { - return Arrays.asList("set", "SPARK_LAUNCHER_USAGE_ERROR=" + message); - } else { - return Arrays.asList("usage", message, "1"); - } + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index d80abf2a8676e..de85720febf23 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -93,6 +93,9 @@ public List buildCommand(Map env) throws IOException { toolsDir.getAbsolutePath(), className); javaOptsKeys.add("SPARK_JAVA_OPTS"); + } else { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + memKey = "SPARK_DRIVER_MEMORY"; } List cmd = buildJavaCommand(extraClassPath); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index d4cfeacb6ef18..c0f89c9230692 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -25,11 +25,12 @@ import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. - *
<p/>
+ * <p>
 * Use this class to start Spark applications programmatically. The class uses a builder pattern
 * to allow clients to configure the Spark application and launch it as a child process.
+ * </p>
    */ public class SparkLauncher { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 7d387d406edae..87c43aa9980e1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -77,6 +77,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } private final List sparkArgs; + private final boolean printHelp; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed @@ -87,10 +88,11 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); + this.printHelp = false; } SparkSubmitCommandBuilder(List args) { - this(); + this.sparkArgs = new ArrayList(); List submitArgs = args; if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { this.allowsMixedArguments = true; @@ -104,14 +106,16 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = false; } - new OptionParser().parse(submitArgs); + OptionParser parser = new OptionParser(); + parser.parse(submitArgs); + this.printHelp = parser.helpRequested; } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -204,7 +208,7 @@ private List buildSparkSubmitCommand(Map env) throws IOE // - properties file. // - SPARK_DRIVER_MEMORY env variable // - SPARK_MEM env variable - // - default value (512m) + // - default value (1g) // Take Thrift Server as daemon String tsMemory = isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; @@ -311,6 +315,8 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { + boolean helpRequested = false; + @Override protected boolean handle(String opt, String value) { if (opt.equals(MASTER)) { @@ -341,6 +347,9 @@ protected boolean handle(String opt, String value) { allowsMixedArguments = true; appResource = specialClasses.get(value); } + } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { + helpRequested = true; + sparkArgs.add(opt); } else { sparkArgs.add(opt); if (value != null) { @@ -360,6 +369,7 @@ protected boolean handleUnknown(String opt) { appArgs.add(opt); return true; } else { + checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); sparkArgs.add(opt); return false; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 229000087688f..b88bba883ac65 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -61,6 +61,7 @@ class SparkSubmitOptionParser { // Options that do not take arguments. 
protected final String HELP = "--help"; protected final String SUPERVISE = "--supervise"; + protected final String USAGE_ERROR = "--usage-error"; protected final String VERBOSE = "--verbose"; protected final String VERSION = "--version"; @@ -120,6 +121,7 @@ class SparkSubmitOptionParser { final String[][] switches = { { HELP, "-h" }, { SUPERVISE }, + { USAGE_ERROR }, { VERBOSE, "-v" }, { VERSION }, }; diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java index 7ed756f4b8591..7c97dba511b28 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -17,13 +17,17 @@ /** * Library for launching Spark applications. - *
<p/>
+ *
+ * <p>
 * This library allows applications to launch Spark programmatically. There's only one entry
 * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class.
- * <p/>
+ * </p>
+ *
+ * <p>
 * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher}
 * and configure the application to run. For example:
- *
+ * </p>
+ *
 * <pre>
      * {@code
      *   import org.apache.spark.launcher.SparkLauncher;
    diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
    index 97043a76cc612..7329ac9f7fb8c 100644
    --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
    +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
    @@ -194,7 +194,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception {
             if (isDriver) {
               assertEquals("-XX:MaxPermSize=256m", arg);
             } else {
    -          assertEquals("-XX:MaxPermSize=128m", arg);
    +          assertEquals("-XX:MaxPermSize=256m", arg);
             }
           }
         }
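The updated assertion above tracks two launcher changes earlier in this patch: the default `-XX:MaxPermSize` is raised from 128m to 256m, and the flag is skipped entirely on IBM JVMs and on Java 8 or later. Below is a rough Scala sketch of that decision only, for illustration; the real logic is the Java code in `AbstractCommandBuilder.addPermGenSizeOpt` and `CommandBuilderUtils.getJavaVendor`, which additionally leaves any caller-supplied MaxPermSize untouched (omitted here).

```scala
// Illustrative sketch, not part of the patch: the vendor/version check the
// launcher performs before emitting -XX:MaxPermSize.
object PermGenOpt {
  def permGenFlag(javaVendor: String, javaVersion: String): Option[String] = {
    // IBM JVMs do not support this flag at all.
    val isIbm = javaVendor.contains("IBM")
    // Java 8 and later removed the permanent generation, so the flag is moot.
    val ver = javaVersion.split("\\.")
    val isJava8OrLater = ver(0).toInt > 1 || ver(1).toInt > 7
    if (isIbm || isJava8OrLater) None else Some("-XX:MaxPermSize=256m")
  }

  def main(args: Array[String]): Unit = {
    println(permGenFlag(sys.props("java.vendor"), sys.props("java.version")))
  }
}
```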
    diff --git a/make-distribution.sh b/make-distribution.sh
    index 1bfa9acb1fe6e..cac7032bb2e87 100755
    --- a/make-distribution.sh
    +++ b/make-distribution.sh
    @@ -58,7 +58,7 @@ while (( "$#" )); do
         --hadoop)
           echo "Error: '--hadoop' is no longer supported:"
           echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead."
    -      echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4."
    +      echo "Error: Related profiles include hadoop-1, hadoop-2.2, hadoop-2.3 and hadoop-2.4."
           exit_with_usage
           ;;
         --with-yarn)
    @@ -141,22 +141,6 @@ SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hi
         # because we use "set -o pipefail"
         echo -n)
     
    -JAVA_CMD="$JAVA_HOME"/bin/java
    -JAVA_VERSION=$("$JAVA_CMD" -version 2>&1)
    -if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then
    -  echo "***NOTE***: JAVA_HOME is not set to a JDK 6 installation. The resulting"
    -  echo "            distribution may not work well with PySpark and will not run"
    -  echo "            with Java 6 (See SPARK-1703 and SPARK-1911)."
    -  echo "            This test can be disabled by adding --skip-java-test."
    -  echo "Output from 'java -version' was:"
    -  echo "$JAVA_VERSION"
    -  read -p "Would you like to continue anyways? [y,n]: " -r
    -  if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then
    -    echo "Okay, exiting."
    -    exit 1
    -  fi
    -fi
    -
     if [ "$NAME" == "none" ]; then
       NAME=$SPARK_HADOOP_VERSION
     fi
    @@ -231,6 +215,12 @@ cp -r "$SPARK_HOME/bin" "$DISTDIR"
     cp -r "$SPARK_HOME/python" "$DISTDIR"
     cp -r "$SPARK_HOME/sbin" "$DISTDIR"
     cp -r "$SPARK_HOME/ec2" "$DISTDIR"
    +# Copy SparkR if it exists
    +if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then
    +  mkdir -p "$DISTDIR"/R/lib
    +  cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib
    +  cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib
    +fi
     
     # Download and copy in tachyon, if requested
     if [ "$SPARK_TACHYON" == "true" ]; then
    diff --git a/mllib/pom.xml b/mllib/pom.xml
    index 0c07ca1a62fd3..a5db14407b4fc 100644
    --- a/mllib/pom.xml
    +++ b/mllib/pom.xml
    @@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
     
    @@ -40,6 +40,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
    @@ -99,7 +106,7 @@
     </dependency>
     <dependency>
       <groupId>org.mockito</groupId>
-      <artifactId>mockito-all</artifactId>
+      <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
     <dependency>
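The new `spark-core` test-jar dependency is what makes `SparkFunSuite` available to downstream test code, the same base class the graphx suites above migrate to from ScalaTest's `FunSuite`. A minimal hedged example of a suite written against it; the suite name and package are hypothetical:

```scala
package org.apache.spark.mllib

// Assumes the spark-core test-jar dependency above is on the test classpath,
// which provides SparkFunSuite to modules other than core.
import org.apache.spark.SparkFunSuite

class ExampleSuite extends SparkFunSuite {
  test("behaves like a plain ScalaTest FunSuite") {
    assert(Seq(1, 2, 3).sum === 6)
  }
}
```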
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
    index 7f3f3262a644f..57e416591de69 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
    @@ -19,16 +19,16 @@ package org.apache.spark.ml
     
     import scala.annotation.varargs
     
    -import org.apache.spark.annotation.AlphaComponent
    -import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.param.{ParamMap, ParamPair}
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    + * :: DeveloperApi ::
      * Abstract class for estimators that fit models to data.
      */
    -@AlphaComponent
    -abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
    +@DeveloperApi
    +abstract class Estimator[M <: Model[M]] extends PipelineStage {
     
       /**
        * Fits a single model to the input data with optional parameters.
    @@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
         paramMaps.map(fit(dataset, _))
       }
     
    -  override def copy(extra: ParamMap): Estimator[M] = {
    -    super.copy(extra).asInstanceOf[Estimator[M]]
    -  }
    +  override def copy(extra: ParamMap): Estimator[M]
     }
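With `copy(extra: ParamMap)` now abstract on `Estimator`, every concrete estimator has to supply its own copy. A minimal sketch of a do-nothing estimator under the Spark 1.5-era `ml` APIs in this patch (`defaultCopy`, `setParent`, `Identifiable.randomUID`); the `NoOp*` names are invented for illustration:

```scala
package org.apache.spark.ml.example

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical model that returns its input untouched.
class NoOpModel(override val uid: String) extends Model[NoOpModel] {
  override def transform(dataset: DataFrame): DataFrame = dataset
  override def transformSchema(schema: StructType): StructType = schema
  // copy() is abstract on Model too, so it is spelled out explicitly.
  override def copy(extra: ParamMap): NoOpModel =
    copyValues(new NoOpModel(uid), extra).setParent(parent)
}

// Hypothetical estimator that "trains" a NoOpModel.
class NoOpEstimator(override val uid: String) extends Estimator[NoOpModel] {
  def this() = this(Identifiable.randomUID("noop"))
  override def fit(dataset: DataFrame): NoOpModel =
    copyValues(new NoOpModel(uid).setParent(this))
  override def transformSchema(schema: StructType): StructType = schema
  // defaultCopy covers the common case: clone this instance and apply extra.
  override def copy(extra: ParamMap): NoOpEstimator = defaultCopy(extra)
}
```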
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
    index 9974efe7b1d25..252acc156583f 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
    @@ -17,25 +17,33 @@
     
     package org.apache.spark.ml
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.ml.param.ParamMap
     
     /**
    - * :: AlphaComponent ::
    + * :: DeveloperApi ::
      * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
      *
      * @tparam M model type
      */
    -@AlphaComponent
    +@DeveloperApi
     abstract class Model[M <: Model[M]] extends Transformer {
       /**
        * The parent estimator that produced this model.
        * Note: For ensembles' component Models, this value can be null.
        */
    -  val parent: Estimator[M]
    +  @transient var parent: Estimator[M] = _
     
    -  override def copy(extra: ParamMap): M = {
    -    // The default implementation of Params.copy doesn't work for models.
    -    throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
    +  /**
    +   * Sets the parent of this model (Java API).
    +   */
    +  def setParent(parent: Estimator[M]): M = {
    +    this.parent = parent
    +    this.asInstanceOf[M]
       }
    +
    +  /** Indicates whether this [[Model]] has a corresponding parent. */
    +  def hasParent: Boolean = parent != null
    +
    +  override def copy(extra: ParamMap): M
     }
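Because `parent` is now a mutable field wired up by `fit()` through `setParent`, callers can inspect the model-to-estimator link with `hasParent`. A small sketch of that behaviour, reusing the hypothetical `NoOpEstimator`/`NoOpModel` from the previous example; the SparkContext/SQLContext setup is only illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.example.{NoOpEstimator, NoOpModel}

object ParentDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("parent-demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq((1.0, 2.0), (3.0, 4.0))).toDF("label", "value")
    val estimator = new NoOpEstimator()
    val model: NoOpModel = estimator.fit(df)

    // fit() attaches the estimator via setParent(this), so both checks hold.
    assert(model.hasParent)
    assert(model.parent eq estimator)
    sc.stop()
  }
}
```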
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
    index 33d430f5671ee..aef2c019d2871 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
    @@ -17,19 +17,23 @@
     
     package org.apache.spark.ml
     
    +import java.{util => ju}
    +
    +import scala.collection.JavaConverters._
     import scala.collection.mutable.ListBuffer
     
     import org.apache.spark.Logging
    -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
    +import org.apache.spark.annotation.{DeveloperApi, Experimental}
     import org.apache.spark.ml.param.{Param, ParamMap, Params}
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.sql.DataFrame
     import org.apache.spark.sql.types.StructType
     
     /**
    - * :: AlphaComponent ::
    + * :: DeveloperApi ::
      * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
      */
    -@AlphaComponent
    +@DeveloperApi
     abstract class PipelineStage extends Params with Logging {
     
       /**
    @@ -62,13 +66,11 @@ abstract class PipelineStage extends Params with Logging {
         outputSchema
       }
     
    -  override def copy(extra: ParamMap): PipelineStage = {
    -    super.copy(extra).asInstanceOf[PipelineStage]
    -  }
    +  override def copy(extra: ParamMap): PipelineStage
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each
      * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the
      * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will
    @@ -79,8 +81,10 @@ abstract class PipelineStage extends Params with Logging {
      * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
      * an identity transformer.
      */
    -@AlphaComponent
    -class Pipeline extends Estimator[PipelineModel] {
    +@Experimental
    +class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
    +
    +  def this() = this(Identifiable.randomUID("pipeline"))
     
       /**
        * param for pipeline stages
    @@ -91,15 +95,14 @@ class Pipeline extends Estimator[PipelineModel] {
       /** @group setParam */
       def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
     
    +  // Below, we clone stages so that modifications to the list of stages will not change
    +  // the Param value in the Pipeline.
       /** @group getParam */
       def getStages: Array[PipelineStage] = $(stages).clone()
     
    -  override def validateParams(paramMap: ParamMap): Unit = {
    -    val map = extractParamMap(paramMap)
    -    getStages.foreach {
    -      case pStage: Params => pStage.validateParams(map)
    -      case _ =>
    -    }
    +  override def validateParams(): Unit = {
    +    super.validateParams()
    +    $(stages).foreach(_.validateParams())
       }
     
       /**
    @@ -148,7 +151,7 @@ class Pipeline extends Estimator[PipelineModel] {
           }
         }
     
    -    new PipelineModel(this, transformers.toArray)
    +    new PipelineModel(uid, transformers.toArray).setParent(this)
       }
     
       override def copy(extra: ParamMap): Pipeline = {
    @@ -166,15 +169,20 @@ class Pipeline extends Estimator[PipelineModel] {
     }
     
     /**
    - * :: AlphaComponent ::
    - * Represents a compiled pipeline.
    + * :: Experimental ::
    + * Represents a fitted pipeline.
      */
    -@AlphaComponent
    +@Experimental
     class PipelineModel private[ml] (
    -    override val parent: Pipeline,
    +    override val uid: String,
         val stages: Array[Transformer])
       extends Model[PipelineModel] with Logging {
     
    +  /** A Java/Python-friendly auxiliary constructor. */
    +  private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
    +    this(uid, stages.asScala.toArray)
    +  }
    +
       override def validateParams(): Unit = {
         super.validateParams()
         stages.foreach(_.validateParams())
    @@ -190,6 +198,6 @@ class PipelineModel private[ml] (
       }
     
       override def copy(extra: ParamMap): PipelineModel = {
    -    new PipelineModel(parent, stages)
    +    new PipelineModel(uid, stages.map(_.copy(extra)))
       }
     }
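In the reworked API, a `Pipeline` carries a `uid`, `getStages` hands back a defensive clone of the stages array, and the fitted `PipelineModel` is linked to its parent via `setParent`. A hedged spark-shell-style sketch, assuming a `sqlContext` and a DataFrame `training` with a `text` column are already in scope:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")

// Every stage, and the pipeline itself, now carries a unique uid.
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF))
println(pipeline.uid)

// getStages returns a clone, so callers cannot mutate the stages Param in place.
assert(pipeline.getStages ne pipeline.getStages)

// fit() produces a PipelineModel whose parent is wired up via setParent(this).
val model = pipeline.fit(training)
assert(model.hasParent)
```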
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    index f6a5f27425d1f..333b42711ec52 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    @@ -58,7 +58,6 @@ private[ml] trait PredictorParams extends Params
     
     /**
      * :: DeveloperApi ::
    - *
      * Abstraction for prediction problems (regression and classification).
      *
      * @tparam FeaturesType  Type of features.
    @@ -88,12 +87,10 @@ abstract class Predictor[
         // This handles a few items such as schema validation.
         // Developers only need to implement train().
         transformSchema(dataset.schema, logging = true)
    -    copyValues(train(dataset))
    +    copyValues(train(dataset).setParent(this))
       }
     
    -  override def copy(extra: ParamMap): Learner = {
    -    super.copy(extra).asInstanceOf[Learner]
    -  }
    +  override def copy(extra: ParamMap): Learner
     
       /**
        * Train a model using the given dataset and parameters.
    @@ -113,7 +110,6 @@ abstract class Predictor[
        *
        * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
        */
    -  @DeveloperApi
       private[ml] def featuresDataType: DataType = new VectorUDT
     
       override def transformSchema(schema: StructType): StructType = {
    @@ -126,15 +122,12 @@ abstract class Predictor[
        */
       protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
         dataset.select($(labelCol), $(featuresCol))
    -      .map { case Row(label: Double, features: Vector) =>
    -      LabeledPoint(label, features)
    -    }
    +      .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }
       }
     }
     
     /**
      * :: DeveloperApi ::
    - *
      * Abstraction for a model for prediction tasks (regression and classification).
      *
      * @tparam FeaturesType  Type of features.
    @@ -176,7 +169,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
       override def transform(dataset: DataFrame): DataFrame = {
         transformSchema(dataset.schema, logging = true)
         if ($(predictionCol).nonEmpty) {
    -      dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
    +      val predictUDF = udf { (features: Any) =>
    +        predict(features.asInstanceOf[FeaturesType])
    +      }
    +      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
         } else {
           this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
             " since no output columns were set.")
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
    index d96b54e511e9c..3c7bcf7590e6d 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
    @@ -20,7 +20,7 @@ package org.apache.spark.ml
     import scala.annotation.varargs
     
     import org.apache.spark.Logging
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
     import org.apache.spark.sql.DataFrame
    @@ -28,11 +28,11 @@ import org.apache.spark.sql.functions._
     import org.apache.spark.sql.types._
     
     /**
    - * :: AlphaComponent ::
    + * :: DeveloperApi ::
      * Abstract class for transformers that transform one dataset into another.
      */
    -@AlphaComponent
    -abstract class Transformer extends PipelineStage with Params {
    +@DeveloperApi
    +abstract class Transformer extends PipelineStage {
     
       /**
        * Transforms the dataset with optional parameters
    @@ -67,16 +67,16 @@ abstract class Transformer extends PipelineStage with Params {
        */
       def transform(dataset: DataFrame): DataFrame
     
    -  override def copy(extra: ParamMap): Transformer = {
    -    super.copy(extra).asInstanceOf[Transformer]
    -  }
    +  override def copy(extra: ParamMap): Transformer
     }
     
     /**
    + * :: DeveloperApi ::
      * Abstract class for transformers that take one input column, apply transformation, and output the
      * result as a new column.
      */
    -private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
    +@DeveloperApi
    +abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
       extends Transformer with HasInputCol with HasOutputCol with Logging {
     
       /** @group setParam */
    @@ -118,4 +118,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
         dataset.withColumn($(outputCol),
           callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
       }
    +
    +  override def copy(extra: ParamMap): T = defaultCopy(extra)
     }
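Since `UnaryTransformer` is now a public `@DeveloperApi` class and picks up `copy` via `defaultCopy`, third-party code can implement a single-column transformer by supplying just `createTransformFunc` and `outputDataType`. A minimal sketch; the `Lowercaser` class is hypothetical:

```scala
package org.apache.spark.ml.example

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{DataType, StringType}

class Lowercaser(override val uid: String)
  extends UnaryTransformer[String, String, Lowercaser] {

  def this() = this(Identifiable.randomUID("lowercaser"))

  // The function applied to every value of the input column.
  override protected def createTransformFunc: String => String = _.toLowerCase

  // The type of the generated output column.
  override protected def outputDataType: DataType = StringType

  // copy() is inherited from UnaryTransformer as defaultCopy(extra).
}
```

A caller would then use it like any other transformer, e.g. `new Lowercaser().setInputCol("raw").setOutputCol("clean").transform(df)`, with cloning handled by the base class.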
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
    index f5f37aa77929c..457c15830fd38 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
    @@ -19,10 +19,12 @@ package org.apache.spark.ml.attribute
     
     import scala.collection.mutable.ArrayBuffer
     
    +import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.mllib.linalg.VectorUDT
     import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}
     
     /**
    + * :: DeveloperApi ::
      * Attributes that describe a vector ML column.
      *
      * @param name name of the attribute group (the ML column name)
    @@ -31,6 +33,7 @@ import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}
      * @param attrs optional array of attributes. Attribute will be copied with their corresponding
      *              indices in the array.
      */
    +@DeveloperApi
     class AttributeGroup private (
         val name: String,
         val numAttributes: Option[Int],
    @@ -182,7 +185,11 @@ class AttributeGroup private (
       }
     }
     
    -/** Factory methods to create attribute groups. */
    +/**
    + * :: DeveloperApi ::
    + * Factory methods to create attribute groups.
    + */
    +@DeveloperApi
     object AttributeGroup {
     
       import AttributeKeys._
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
    index a83febd7de2cc..5c7089b491677 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
    @@ -17,12 +17,17 @@
     
     package org.apache.spark.ml.attribute
     
    +import org.apache.spark.annotation.DeveloperApi
    +
     /**
    + * :: DeveloperApi ::
      * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]],
      * and [[AttributeType$#Binary]].
      */
    +@DeveloperApi
     sealed abstract class AttributeType(val name: String)
     
    +@DeveloperApi
     object AttributeType {
     
       /** Numeric type. */
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
    index e8f7f152784a1..e479f169021d8 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
    @@ -19,11 +19,14 @@ package org.apache.spark.ml.attribute
     
     import scala.annotation.varargs
     
    -import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField}
     
     /**
    + * :: DeveloperApi ::
      * Abstract class for ML attributes.
      */
    +@DeveloperApi
     sealed abstract class Attribute extends Serializable {
     
       name.foreach { n =>
    @@ -124,7 +127,7 @@ private[attribute] trait AttributeFactory {
        * Creates an [[Attribute]] from a [[StructField]] instance.
        */
       def fromStructField(field: StructField): Attribute = {
    -    require(field.dataType == DoubleType)
    +    require(field.dataType.isInstanceOf[NumericType])
         val metadata = field.metadata
         val mlAttr = AttributeKeys.ML_ATTR
         if (metadata.contains(mlAttr)) {
    @@ -135,6 +138,10 @@ private[attribute] trait AttributeFactory {
       }
     }
     
    +/**
    + * :: DeveloperApi ::
    + */
    +@DeveloperApi
     object Attribute extends AttributeFactory {
     
       private[attribute] override def fromMetadata(metadata: Metadata): Attribute = {
    @@ -163,6 +170,7 @@ object Attribute extends AttributeFactory {
     
     
     /**
    + * :: DeveloperApi ::
      * A numeric attribute with optional summary statistics.
      * @param name optional name
      * @param index optional index
    @@ -171,6 +179,7 @@ object Attribute extends AttributeFactory {
      * @param std optional standard deviation
      * @param sparsity optional sparsity (ratio of zeros)
      */
    +@DeveloperApi
     class NumericAttribute private[ml] (
         override val name: Option[String] = None,
         override val index: Option[Int] = None,
    @@ -278,8 +287,10 @@ class NumericAttribute private[ml] (
     }
     
     /**
    + * :: DeveloperApi ::
      * Factory methods for numeric attributes.
      */
    +@DeveloperApi
     object NumericAttribute extends AttributeFactory {
     
       /** The default numeric attribute. */
    @@ -298,6 +309,7 @@ object NumericAttribute extends AttributeFactory {
     }
     
     /**
    + * :: DeveloperApi ::
      * A nominal attribute.
      * @param name optional name
      * @param index optional index
    @@ -306,6 +318,7 @@ object NumericAttribute extends AttributeFactory {
      *                  defined.
      * @param values optional values. At most one of `numValues` and `values` can be defined.
      */
    +@DeveloperApi
     class NominalAttribute private[ml] (
         override val name: Option[String] = None,
         override val index: Option[Int] = None,
    @@ -430,7 +443,11 @@ class NominalAttribute private[ml] (
       }
     }
     
    -/** Factory methods for nominal attributes. */
    +/**
    + * :: DeveloperApi ::
    + * Factory methods for nominal attributes.
    + */
    +@DeveloperApi
     object NominalAttribute extends AttributeFactory {
     
       /** The default nominal attribute. */
    @@ -450,11 +467,13 @@ object NominalAttribute extends AttributeFactory {
     }
     
     /**
    + * :: DeveloperApi ::
      * A binary attribute.
      * @param name optional name
      * @param index optional index
      * @param values optionla values. If set, its size must be 2.
      */
    +@DeveloperApi
     class BinaryAttribute private[ml] (
         override val name: Option[String] = None,
         override val index: Option[Int] = None,
    @@ -526,7 +545,11 @@ class BinaryAttribute private[ml] (
       }
     }
     
    -/** Factory methods for binary attributes. */
    +/**
    + * :: DeveloperApi ::
    + * Factory methods for binary attributes.
    + */
    +@DeveloperApi
     object BinaryAttribute extends AttributeFactory {
     
       /** The default binary attribute. */
    @@ -543,8 +566,10 @@ object BinaryAttribute extends AttributeFactory {
     }
     
     /**
    + * :: DeveloperApi ::
      * An unresolved attribute.
      */
    +@DeveloperApi
     object UnresolvedAttribute extends Attribute {
     
       override def attrType: AttributeType = AttributeType.Unresolved
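
     Since `fromStructField` now only requires a `NumericType` rather than strictly `DoubleType`, ML attribute metadata can be recovered from integer-typed columns as well. A minimal round-trip sketch (the "age" column name is illustrative, not part of this patch):

     ```scala
     import org.apache.spark.ml.attribute.{Attribute, NumericAttribute}
     import org.apache.spark.sql.types.{IntegerType, StructField}

     // Attach ML attribute metadata to an IntegerType field; with the relaxed require,
     // fromStructField resolves it just as it did for DoubleType columns.
     val attr = NumericAttribute.defaultAttr.withName("age")
     val field = StructField("age", IntegerType, nullable = false, attr.toMetadata())
     val recovered: Attribute = Attribute.fromStructField(field)
     ```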
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
    index 263d580fe2dd3..85c097bc64a4f 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
    @@ -18,6 +18,7 @@
     package org.apache.spark.ml.classification
     
     import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
     import org.apache.spark.ml.param.shared.HasRawPredictionCol
     import org.apache.spark.ml.util.SchemaUtils
    @@ -101,15 +102,20 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
         var outputData = dataset
         var numColsOutput = 0
         if (getRawPredictionCol != "") {
    -      outputData = outputData.withColumn(getRawPredictionCol,
    -        callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
    +      val predictRawUDF = udf { (features: Any) =>
    +        predictRaw(features.asInstanceOf[FeaturesType])
    +      }
    +      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
           numColsOutput += 1
         }
         if (getPredictionCol != "") {
           val predUDF = if (getRawPredictionCol != "") {
    -        callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol))
    +        udf(raw2prediction _).apply(col(getRawPredictionCol))
           } else {
    -        callUDF(predict _, DoubleType, col(getFeaturesCol))
    +        val predictUDF = udf { (features: Any) =>
    +          predict(features.asInstanceOf[FeaturesType])
    +        }
    +        predictUDF(col(getFeaturesCol))
           }
           outputData = outputData.withColumn(getPredictionCol, predUDF)
           numColsOutput += 1
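
     The `callUDF(f, dataType, col)` calls above are replaced by `udf`-wrapped closures applied directly to columns, so the return type comes from the closure instead of an explicit DataType argument. A minimal sketch of the same pattern on a plain DataFrame (the `df` and column names are assumptions, not part of this patch):

     ```scala
     import org.apache.spark.sql.DataFrame
     import org.apache.spark.sql.functions.{col, udf}

     // Wrap a Scala function with udf and apply the resulting Column expression.
     def withPrediction(df: DataFrame): DataFrame = {
       val predictUDF = udf { (score: Double) => if (score > 0.5) 1.0 else 0.0 }
       df.withColumn("prediction", predictUDF(col("score")))
     }
     ```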
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    index dcebea1d4b015..2dc1824964a42 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    @@ -17,11 +17,11 @@
     
     package org.apache.spark.ml.classification
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
    -import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
    -import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
    +import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
    @@ -31,18 +31,19 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
      * for classification.
      * It supports both binary and multiclass labels, as well as both continuous and categorical
      * features.
      */
    -@AlphaComponent
    -final class DecisionTreeClassifier
    +@Experimental
    +final class DecisionTreeClassifier(override val uid: String)
       extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
       with DecisionTreeParams with TreeClassifierParams {
     
    +  def this() = this(Identifiable.randomUID("dtc"))
    +
       // Override parameter setters from parent trait for Java API compatibility.
     
       override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
    @@ -85,23 +86,25 @@ final class DecisionTreeClassifier
         super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
           subsamplingRate = 1.0)
       }
    +
    +  override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
     }
     
    +@Experimental
     object DecisionTreeClassifier {
       /** Accessor for supported impurities: entropy, gini */
       final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
      * It supports both binary and multiclass labels, as well as both continuous and categorical
      * features.
      */
    -@AlphaComponent
    +@Experimental
     final class DecisionTreeClassificationModel private[ml] (
    -    override val parent: DecisionTreeClassifier,
    +    override val uid: String,
         override val rootNode: Node)
       extends PredictionModel[Vector, DecisionTreeClassificationModel]
       with DecisionTreeModel with Serializable {
    @@ -114,7 +117,7 @@ final class DecisionTreeClassificationModel private[ml] (
       }
     
       override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
    -    copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra)
    +    copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
       }
     
       override def toString: String = {
    @@ -138,6 +141,7 @@ private[ml] object DecisionTreeClassificationModel {
           s"Cannot convert non-classification DecisionTreeModel (old API) to" +
             s" DecisionTreeClassificationModel (new API).  Algo is: ${oldModel.algo}")
         val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
    -    new DecisionTreeClassificationModel(parent, rootNode)
    +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
    +    new DecisionTreeClassificationModel(uid, rootNode)
       }
     }
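
     Every estimator and model in this patch now takes an explicit `uid` in its primary constructor and falls back to a random one (e.g. `Identifiable.randomUID("dtc")`) in the auxiliary constructor. A self-contained sketch of the prefix-plus-random-suffix convention (the helper below is illustrative, not Spark's implementation):

     ```scala
     import java.util.UUID

     // Illustrative stand-in for the randomUID convention used throughout this patch.
     def randomUID(prefix: String): String =
       prefix + "_" + UUID.randomUUID().toString.takeRight(12)

     val uid = randomUID("dtc")  // e.g. "dtc_4c1a0f2b9d3e"
     ```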
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    index ae51b05a0c42d..554e3b8e052b2 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    @@ -20,12 +20,12 @@ package org.apache.spark.ml.classification
     import com.github.fommil.netlib.BLAS.{getInstance => blas}
     
     import org.apache.spark.Logging
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.{Param, ParamMap}
     import org.apache.spark.ml.regression.DecisionTreeRegressionModel
    -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
    -import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel}
    +import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
    @@ -36,18 +36,19 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
      * learning algorithm for classification.
      * It supports binary labels, as well as both continuous and categorical features.
      * Note: Multiclass labels are not currently supported.
      */
    -@AlphaComponent
    -final class GBTClassifier
    +@Experimental
    +final class GBTClassifier(override val uid: String)
       extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
       with GBTParams with TreeClassifierParams with Logging {
     
    +  def this() = this(Identifiable.randomUID("gbtc"))
    +
       // Override parameter setters from parent trait for Java API compatibility.
     
       // Parameters from TreeClassifierParams:
    @@ -140,8 +141,11 @@ final class GBTClassifier
         val oldModel = oldGBT.run(oldDataset)
         GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
       }
    +
    +  override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
     }
     
    +@Experimental
     object GBTClassifier {
       // The losses below should be lowercase.
       /** Accessor for supported loss settings: logistic */
    @@ -149,8 +153,7 @@ object GBTClassifier {
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
      * model for classification.
      * It supports binary labels, as well as both continuous and categorical features.
    @@ -158,9 +161,9 @@ object GBTClassifier {
      * @param _trees  Decision trees in the ensemble.
      * @param _treeWeights  Weights for the decision trees in the ensemble.
      */
    -@AlphaComponent
    +@Experimental
     final class GBTClassificationModel(
    -    override val parent: GBTClassifier,
    +    override val uid: String,
         private val _trees: Array[DecisionTreeRegressionModel],
         private val _treeWeights: Array[Double])
       extends PredictionModel[Vector, GBTClassificationModel]
    @@ -184,7 +187,7 @@ final class GBTClassificationModel(
       }
     
       override def copy(extra: ParamMap): GBTClassificationModel = {
    -    copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra)
    +    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra)
       }
     
       override def toString: String = {
    @@ -207,9 +210,10 @@ private[ml] object GBTClassificationModel {
         require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
           s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
         val newTrees = oldModel.trees.map { tree =>
    -      // parent, fittingParamMap for each tree is null since there are no good ways to set these.
    +      // parent for each tree is null since there is no good way to set this.
           DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
         }
    -    new GBTClassificationModel(parent, newTrees, oldModel.treeWeights)
    +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
     +    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights)
       }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
    index 93ba91167bfad..8fc9199fb4602 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
    @@ -19,13 +19,14 @@ package org.apache.spark.ml.classification
     
     import scala.collection.mutable
     
    -import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
    -import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
    -import breeze.optimize.{CachedDiffFunction, DiffFunction}
    +import breeze.linalg.{DenseVector => BDV}
    +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.{Logging, SparkException}
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.mllib.linalg._
     import org.apache.spark.mllib.linalg.BLAS._
     import org.apache.spark.mllib.regression.LabeledPoint
    @@ -34,26 +35,26 @@ import org.apache.spark.mllib.util.MLUtils
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     import org.apache.spark.storage.StorageLevel
    -import org.apache.spark.{SparkException, Logging}
     
     /**
      * Params for logistic regression.
      */
     private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
       with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
    -  with HasThreshold
    +  with HasThreshold with HasStandardization
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Logistic regression.
      * Currently, this class only supports binary classification.
      */
    -@AlphaComponent
    -class LogisticRegression
    +@Experimental
    +class LogisticRegression(override val uid: String)
       extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
       with LogisticRegressionParams with Logging {
     
    +  def this() = this(Identifiable.randomUID("logreg"))
    +
       /**
        * Set the regularization parameter.
        * Default is 0.0.
    @@ -73,7 +74,7 @@ class LogisticRegression
       setDefault(elasticNetParam -> 0.0)
     
       /**
    -   * Set the maximal number of iterations.
    +   * Set the maximum number of iterations.
        * Default is 100.
        * @group setParam
        */
    @@ -89,10 +90,26 @@ class LogisticRegression
       def setTol(value: Double): this.type = set(tol, value)
       setDefault(tol -> 1E-6)
     
    -  /** @group setParam */
    +  /**
    +   * Whether to fit an intercept term.
    +   * Default is true.
    +   * @group setParam
     +   */
       def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
       setDefault(fitIntercept -> true)
     
    +  /**
    +   * Whether to standardize the training features before fitting the model.
     +   * The coefficients of models will always be returned on the original scale,
     +   * so it will be transparent for users. Note that when no regularization is
     +   * applied, the models should always converge to the same solution with or
     +   * without standardization.
     +   * Default is true.
     +   * @group setParam
     +   */
    +  def setStandardization(value: Boolean): this.type = set(standardization, value)
    +  setDefault(standardization -> true)
    +
       /** @group setParam */
       def setThreshold(value: Double): this.type = set(threshold, value)
       setDefault(threshold -> 0.5)
    @@ -111,7 +128,7 @@ class LogisticRegression
               case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
               (label: Double, features: Vector)) =>
                 (summarizer.add(features), labelSummarizer.add(label))
    -      },
    +        },
             combOp = (c1, c2) => (c1, c2) match {
               case ((summarizer1: MultivariateOnlineSummarizer,
               classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
    @@ -144,15 +161,28 @@ class LogisticRegression
         val regParamL1 = $(elasticNetParam) * $(regParam)
         val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
     
    -    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
    +    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
           featuresStd, featuresMean, regParamL2)
     
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
         } else {
    -      // Remove the L1 penalization on the intercept
           def regParamL1Fun = (index: Int) => {
    -        if (index == numFeatures) 0.0 else regParamL1
    +        // Remove the L1 penalization on the intercept
    +        if (index == numFeatures) {
    +          0.0
    +        } else {
    +          if ($(standardization)) {
    +            regParamL1
    +          } else {
    +            // If `standardization` is false, we still standardize the data
    +            // to improve the rate of convergence; as a result, we have to
    +            // perform this reverse standardization by penalizing each component
    +            // differently to get effectively the same objective function when
    +            // the training dataset is not standardized.
    +            if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
    +          }
    +        }
           }
           new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
         }
    @@ -161,18 +191,18 @@ class LogisticRegression
           Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
     
         if ($(fitIntercept)) {
    -      /**
    -       * For binary logistic regression, when we initialize the weights as zeros,
    -       * it will converge faster if we initialize the intercept such that
    -       * it follows the distribution of the labels.
    -       *
    -       * {{{
    -       * P(0) = 1 / (1 + \exp(b)), and
    -       * P(1) = \exp(b) / (1 + \exp(b))
    -       * }}}, hence
    -       * {{{
    -       * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
    -       * }}}
    +      /*
    +         For binary logistic regression, when we initialize the weights as zeros,
    +         it will converge faster if we initialize the intercept such that
    +         it follows the distribution of the labels.
    +
    +         {{{
    +         P(0) = 1 / (1 + \exp(b)), and
    +         P(1) = \exp(b) / (1 + \exp(b))
    +         }}}, hence
    +         {{{
    +         b = \log{P(1) / P(0)} = \log{count_1 / count_0}
    +         }}}
            */
           initialWeightsWithIntercept.toArray(numFeatures)
             = math.log(histogram(1).toDouble / histogram(0).toDouble)
    @@ -181,50 +211,60 @@ class LogisticRegression
         val states = optimizer.iterations(new CachedDiffFunction(costFun),
           initialWeightsWithIntercept.toBreeze.toDenseVector)
     
    -    var state = states.next()
    -    val lossHistory = mutable.ArrayBuilder.make[Double]
    +    val (weights, intercept, objectiveHistory) = {
    +      /*
     +         Note that in Logistic Regression, the objective history (loss + regularization)
     +         is the log-likelihood, which is invariant under feature standardization. Hence
     +         the objective history from the optimizer matches the one in the original space.
    +       */
    +      val arrayBuilder = mutable.ArrayBuilder.make[Double]
    +      var state: optimizer.State = null
    +      while (states.hasNext) {
    +        state = states.next()
    +        arrayBuilder += state.adjustedValue
    +      }
     
    -    while (states.hasNext) {
    -      lossHistory += state.value
    -      state = states.next()
    -    }
    -    lossHistory += state.value
    +      if (state == null) {
    +        val msg = s"${optimizer.getClass.getName} failed."
    +        logError(msg)
    +        throw new SparkException(msg)
    +      }
     
    -    // The weights are trained in the scaled space; we're converting them back to
    -    // the original space.
    -    val weightsWithIntercept = {
    +      /*
    +         The weights are trained in the scaled space; we're converting them back to
    +         the original space.
    +         Note that the intercept in scaled space and original space is the same;
    +         as a result, no scaling is needed.
    +       */
           val rawWeights = state.x.toArray.clone()
           var i = 0
    -      // Note that the intercept in scaled space and original space is the same;
    -      // as a result, no scaling is needed.
           while (i < numFeatures) {
             rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
             i += 1
           }
    -      Vectors.dense(rawWeights)
    +
    +      if ($(fitIntercept)) {
    +        (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result())
    +      } else {
    +        (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result())
    +      }
         }
     
         if (handlePersistence) instances.unpersist()
     
    -    val (weights, intercept) = if ($(fitIntercept)) {
    -      (Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)),
    -        weightsWithIntercept(weightsWithIntercept.size - 1))
    -    } else {
    -      (weightsWithIntercept, 0.0)
    -    }
    -
    -    new LogisticRegressionModel(this, weights.compressed, intercept)
    +    copyValues(new LogisticRegressionModel(uid, weights, intercept))
       }
    +
    +  override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Model produced by [[LogisticRegression]].
      */
    -@AlphaComponent
    +@Experimental
     class LogisticRegressionModel private[ml] (
    -    override val parent: LogisticRegression,
    +    override val uid: String,
         val weights: Vector,
         val intercept: Double)
       extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
    @@ -258,7 +298,8 @@ class LogisticRegressionModel private[ml] (
         rawPrediction match {
           case dv: DenseVector =>
             var i = 0
    -        while (i < dv.size) {
    +        val size = dv.size
    +        while (i < size) {
               dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
               i += 1
             }
    @@ -275,7 +316,7 @@ class LogisticRegressionModel private[ml] (
       }
     
       override def copy(extra: ParamMap): LogisticRegressionModel = {
    -    copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
    +    copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
       }
     
       override protected def raw2prediction(rawPrediction: Vector): Double = {
    @@ -357,7 +398,8 @@ private[classification] class MultiClassSummarizer extends Serializable {
       def histogram: Array[Long] = {
         val result = Array.ofDim[Long](numClasses)
         var i = 0
    -    while (i < result.length) {
    +    val len = result.length
    +    while (i < len) {
           result(i) = distinctMap.getOrElse(i, 0L)
           i += 1
         }
    @@ -415,16 +457,12 @@ private class LogisticAggregator(
         require(dim == data.size, s"Dimensions mismatch when adding new sample." +
           s" Expecting $dim but got ${data.size}.")
     
    -    val dataSize = data.size
    -
         val localWeightsArray = weightsArray
         val localGradientSumArray = gradientSumArray
     
         numClasses match {
           case 2 =>
    -        /**
    -         * For Binary Logistic Regression.
    -         */
    +        // For Binary Logistic Regression.
             val margin = - {
               var sum = 0.0
               data.foreachActive { (index, value) =>
    @@ -480,7 +518,8 @@ private class LogisticAggregator(
           var i = 0
           val localThisGradientSumArray = this.gradientSumArray
           val localOtherGradientSumArray = other.gradientSumArray
    -      while (i < localThisGradientSumArray.length) {
    +      val len = localThisGradientSumArray.length
    +      while (i < len) {
             localThisGradientSumArray(i) += localOtherGradientSumArray(i)
             i += 1
           }
    @@ -509,11 +548,13 @@ private class LogisticCostFun(
         data: RDD[(Double, Vector)],
         numClasses: Int,
         fitIntercept: Boolean,
    +    standardization: Boolean,
         featuresStd: Array[Double],
         featuresMean: Array[Double],
         regParamL2: Double) extends DiffFunction[BDV[Double]] {
     
       override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
    +    val numFeatures = featuresStd.length
         val w = Vectors.fromBreeze(weights)
     
         val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
    @@ -525,27 +566,43 @@ private class LogisticCostFun(
               case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
             })
     
    -    // regVal is the sum of weight squares for L2 regularization
    -    val norm = if (regParamL2 == 0.0) {
    -      0.0
    -    } else if (fitIntercept) {
    -      brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
    -    } else {
    -      brzNorm(weights, 2.0)
    -    }
    -    val regVal = 0.5 * regParamL2 * norm * norm
    +    val totalGradientArray = logisticAggregator.gradient.toArray
     
    -    val loss = logisticAggregator.loss + regVal
    -    val gradient = logisticAggregator.gradient
    -
    -    if (fitIntercept) {
    -      val wArray = w.toArray.clone()
    -      wArray(wArray.length - 1) = 0.0
    -      axpy(regParamL2, Vectors.dense(wArray), gradient)
    +    // regVal is the sum of weight squares excluding intercept for L2 regularization.
    +    val regVal = if (regParamL2 == 0.0) {
    +      0.0
         } else {
    -      axpy(regParamL2, w, gradient)
    +      var sum = 0.0
    +      w.foreachActive { (index, value) =>
     +        // If `fitIntercept` is true, the last term, which is the intercept, doesn't
     +        // contribute to the regularization.
     +        if (index != numFeatures) {
     +          // The following code computes both the regularization loss and its gradient,
     +          // and adds the gradient back into totalGradientArray.
    +          sum += {
    +            if (standardization) {
    +              totalGradientArray(index) += regParamL2 * value
    +              value * value
    +            } else {
    +              if (featuresStd(index) != 0.0) {
    +                // If `standardization` is false, we still standardize the data
    +                // to improve the rate of convergence; as a result, we have to
    +                // perform this reverse standardization by penalizing each component
    +                // differently to get effectively the same objective function when
    +                // the training dataset is not standardized.
    +                val temp = value / (featuresStd(index) * featuresStd(index))
    +                totalGradientArray(index) += regParamL2 * temp
    +                value * temp
    +              } else {
    +                0.0
    +              }
    +            }
    +          }
    +        }
    +      }
    +      0.5 * regParamL2 * sum
         }
     
    -    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
    +    (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
       }
     }
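
     The common idea in both the L1 and L2 branches above: optimization always runs on standardized features, and when `standardization` is false the per-component penalty is rescaled so the objective matches the one defined on the raw features. A standalone restatement of the L1 rule (names mirror `regParamL1Fun` above; a sketch, not the patch itself):

     ```scala
     // Effective L1 regularization applied to the weight at `index` during optimization.
     def effectiveRegParamL1(
         index: Int,
         numFeatures: Int,
         regParamL1: Double,
         standardization: Boolean,
         featuresStd: Array[Double]): Double = {
       if (index == numFeatures) {
         0.0 // the intercept is never penalized
       } else if (standardization) {
         regParamL1
       } else if (featuresStd(index) != 0.0) {
         // Undo the implicit feature scaling so the objective matches the unscaled data.
         regParamL1 / featuresStd(index)
       } else {
         0.0
       }
     }
     ```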
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
    index afb8d75d57384..ea757c5e40c76 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
    @@ -21,11 +21,11 @@ import java.util.UUID
     
     import scala.language.existentials
     
    -import org.apache.spark.annotation.{AlphaComponent, Experimental}
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml._
     import org.apache.spark.ml.attribute._
    -import org.apache.spark.ml.param.Param
    -import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.ml.param.{Param, ParamMap}
    +import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.sql.{DataFrame, Row}
     import org.apache.spark.sql.functions._
    @@ -37,27 +37,26 @@ import org.apache.spark.storage.StorageLevel
      */
     private[ml] trait OneVsRestParams extends PredictorParams {
     
    +  // scalastyle:off structural.type
       type ClassifierType = Classifier[F, E, M] forSome {
         type F
         type M <: ClassificationModel[F, M]
    -    type E <:  Classifier[F, E, M]
    +    type E <: Classifier[F, E, M]
       }
    +  // scalastyle:on structural.type
     
       /**
        * param for the base binary classifier that we reduce multiclass classification into.
        * @group param
        */
    -  val classifier: Param[ClassifierType]  =
    -    new Param(this, "classifier", "base binary classifier ")
    +  val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
     
       /** @group getParam */
       def getClassifier: ClassifierType = $(classifier)
    -
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Model produced by [[OneVsRest]].
      * This stores the models resulting from training k binary classifiers: one for each class.
      * Each example is scored against all k models, and the model with the highest score
    @@ -69,11 +68,11 @@ private[ml] trait OneVsRestParams extends PredictorParams {
      *               The i-th model is produced by testing the i-th class (taking label 1) vs the rest
      *               (taking label 0).
      */
    -@AlphaComponent
    -class OneVsRestModel private[ml] (
    -      override val parent: OneVsRest,
    -      labelMetadata: Metadata,
    -      val models: Array[_ <: ClassificationModel[_,_]])
    +@Experimental
    +final class OneVsRestModel private[ml] (
    +    override val uid: String,
    +    labelMetadata: Metadata,
    +    val models: Array[_ <: ClassificationModel[_, _]])
       extends Model[OneVsRestModel] with OneVsRestParams {
     
       override def transformSchema(schema: StructType): StructType = {
    @@ -89,9 +88,9 @@ class OneVsRestModel private[ml] (
     
         // add an accumulator column to store predictions of all the models
         val accColName = "mbc$acc" + UUID.randomUUID().toString
    -    val init: () => Map[Int, Double] = () => {Map()}
    +    val initUDF = udf { () => Map[Int, Double]() }
         val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
    -    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
    +    val newDataset = dataset.withColumn(accColName, initUDF())
     
         // persist if underlying dataset is not persistent.
         val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    @@ -107,17 +106,16 @@ class OneVsRestModel private[ml] (
     
             // add temporary column to store intermediate scores and update
             val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
    -        val update: (Map[Int, Double], Vector) => Map[Int, Double]  =
    -          (predictions: Map[Int, Double], prediction: Vector) => {
    -            predictions + ((index, prediction(1)))
    -          }
    -        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
    -        val transformedDataset = model.transform(df).select(columns:_*)
    -        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
    +        val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
    +          predictions + ((index, prediction(1)))
    +        }
    +        val transformedDataset = model.transform(df).select(columns : _*)
    +        val updatedDataset = transformedDataset
    +          .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
             val newColumns = origCols ++ List(col(tmpColName))
     
             // switch out the intermediate column with the accumulator column
    -        updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
    +        updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName)
         }
     
         if (handlePersistence) {
    @@ -125,13 +123,20 @@ class OneVsRestModel private[ml] (
         }
     
         // output the index of the classifier with highest confidence as prediction
    -    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
    +    val labelUDF = udf { (predictions: Map[Int, Double]) =>
           predictions.maxBy(_._2)._1.toDouble
         }
     
         // output label and label metadata as prediction
    -    val labelUdf = callUDF(label, DoubleType, col(accColName))
    -    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
    +    aggregatedDataset
    +      .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
    +      .drop(accColName)
    +  }
    +
    +  override def copy(extra: ParamMap): OneVsRestModel = {
    +    val copied = new OneVsRestModel(
    +      uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
    +    copyValues(copied, extra)
       }
     }
     
    @@ -145,11 +150,13 @@ class OneVsRestModel private[ml] (
      * is picked to label the example.
      */
     @Experimental
    -final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
    +final class OneVsRest(override val uid: String)
    +  extends Estimator[OneVsRestModel] with OneVsRestParams {
    +
    +  def this() = this(Identifiable.randomUID("oneVsRest"))
     
       /** @group setParam */
    -  def setClassifier(value: Classifier[_,_,_]): this.type = {
    -    // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed
    +  def setClassifier(value: Classifier[_, _, _]): this.type = {
         set(classifier, value.asInstanceOf[ClassifierType])
       }
     
    @@ -177,21 +184,19 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
     
         // create k columns, one for each binary classifier.
         val models = Range(0, numClasses).par.map { index =>
    -
    -      val label: Double => Double = (label: Double) => {
    +      val labelUDF = udf { (label: Double) =>
             if (label.toInt == index) 1.0 else 0.0
           }
     
           // generate new label metadata for the binary problem.
           // TODO: use when ... otherwise after SPARK-7321 is merged
    -      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
           val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
           val labelColName = "mc2b$" + index
    -      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
    +      val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
           val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
           val classifier = getClassifier
           classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
    -    }.toArray[ClassificationModel[_,_]]
    +    }.toArray[ClassificationModel[_, _]]
     
         if (handlePersistence) {
           multiclassLabeled.unpersist()
    @@ -204,6 +209,15 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
             NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
           case attr: Attribute => attr
         }
    -    copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
    +    val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
    +    copyValues(model)
    +  }
    +
    +  override def copy(extra: ParamMap): OneVsRest = {
    +    val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
    +    if (isDefined(classifier)) {
    +      copied.setClassifier($(classifier).copy(extra))
    +    }
    +    copied
       }
     }
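
     With the new `uid` constructor and udf-based label remapping, typical `OneVsRest` usage is unchanged. A hedged sketch, assuming a `training` DataFrame with "label" and "features" columns:

     ```scala
     import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}
     import org.apache.spark.sql.DataFrame

     // Fit one binary LogisticRegression per class and combine them into a OneVsRestModel.
     def trainOneVsRest(training: DataFrame): OneVsRestModel = {
       val base = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
       new OneVsRest().setClassifier(base).fit(training)
     }
     ```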
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
    index 330ae2938f4e0..38e832372698c 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
    @@ -98,26 +98,34 @@ private[spark] abstract class ProbabilisticClassificationModel[
         var outputData = dataset
         var numColsOutput = 0
         if ($(rawPredictionCol).nonEmpty) {
    -      outputData = outputData.withColumn(getRawPredictionCol,
    -        callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
    +      val predictRawUDF = udf { (features: Any) =>
    +        predictRaw(features.asInstanceOf[FeaturesType])
    +      }
    +      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
           numColsOutput += 1
         }
         if ($(probabilityCol).nonEmpty) {
           val probUDF = if ($(rawPredictionCol).nonEmpty) {
    -        callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol)))
    +        udf(raw2probability _).apply(col($(rawPredictionCol)))
           } else {
    -        callUDF(predictProbability _, new VectorUDT, col($(featuresCol)))
    +        val probabilityUDF = udf { (features: Any) =>
    +          predictProbability(features.asInstanceOf[FeaturesType])
    +        }
    +        probabilityUDF(col($(featuresCol)))
           }
           outputData = outputData.withColumn($(probabilityCol), probUDF)
           numColsOutput += 1
         }
         if ($(predictionCol).nonEmpty) {
           val predUDF = if ($(rawPredictionCol).nonEmpty) {
    -        callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol)))
    +        udf(raw2prediction _).apply(col($(rawPredictionCol)))
           } else if ($(probabilityCol).nonEmpty) {
    -        callUDF(probability2prediction _, DoubleType, col($(probabilityCol)))
    +        udf(probability2prediction _).apply(col($(probabilityCol)))
           } else {
    -        callUDF(predict _, DoubleType, col($(featuresCol)))
    +        val predictUDF = udf { (features: Any) =>
    +          predict(features.asInstanceOf[FeaturesType])
    +        }
    +        predictUDF(col($(featuresCol)))
           }
           outputData = outputData.withColumn($(predictionCol), predUDF)
           numColsOutput += 1
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    index 9954893f14359..d3c67494a31e4 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification
     
     import scala.collection.mutable
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
    -import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
    -import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
    +import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
    @@ -33,18 +33,19 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] learning algorithm for
      * classification.
      * It supports both binary and multiclass labels, as well as both continuous and categorical
      * features.
      */
    -@AlphaComponent
    -final class RandomForestClassifier
    +@Experimental
    +final class RandomForestClassifier(override val uid: String)
       extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
       with RandomForestParams with TreeClassifierParams {
     
    +  def this() = this(Identifiable.randomUID("rfc"))
    +
       // Override parameter setters from parent trait for Java API compatibility.
     
       // Parameters from TreeClassifierParams:
    @@ -96,8 +97,11 @@ final class RandomForestClassifier
           oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
         RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
       }
    +
    +  override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
     }
     
    +@Experimental
     object RandomForestClassifier {
       /** Accessor for supported impurity settings: entropy, gini */
       final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
    @@ -108,17 +112,16 @@ object RandomForestClassifier {
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] model for classification.
      * It supports both binary and multiclass labels, as well as both continuous and categorical
      * features.
      * @param _trees  Decision trees in the ensemble.
      *               Warning: These have null parents.
      */
    -@AlphaComponent
    +@Experimental
     final class RandomForestClassificationModel private[ml] (
    -    override val parent: RandomForestClassifier,
    +    override val uid: String,
         private val _trees: Array[DecisionTreeClassificationModel])
       extends PredictionModel[Vector, RandomForestClassificationModel]
       with TreeEnsembleModel with Serializable {
    @@ -146,7 +149,7 @@ final class RandomForestClassificationModel private[ml] (
       }
     
       override def copy(extra: ParamMap): RandomForestClassificationModel = {
    -    copyValues(new RandomForestClassificationModel(parent, _trees), extra)
    +    copyValues(new RandomForestClassificationModel(uid, _trees), extra)
       }
     
       override def toString: String = {
    @@ -169,9 +172,10 @@ private[ml] object RandomForestClassificationModel {
         require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
           s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
         val newTrees = oldModel.trees.map { tree =>
    -      // parent, fittingParamMap for each tree is null since there are no good ways to set these.
    +      // parent for each tree is null since there is no good way to set this.
           DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
         }
    -    new RandomForestClassificationModel(parent, newTrees)
    +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
    +    new RandomForestClassificationModel(uid, newTrees)
       }
     }
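
     As with the other tree-based estimators, `RandomForestClassifier` now carries a `uid` and gains a `copy` method; fitting is otherwise unchanged. A brief usage sketch (the `training` DataFrame and parameter values are assumptions):

     ```scala
     import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
     import org.apache.spark.sql.DataFrame

     // Train a small forest; the fitted model inherits the estimator's uid via fit().
     def trainForest(training: DataFrame): RandomForestClassificationModel = {
       val rf = new RandomForestClassifier()
         .setNumTrees(20)
         .setMaxDepth(5)
       rf.fit(training)
     }
     ```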
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
    index e5a73c6087a11..4a82b77f0edcb 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
    @@ -17,23 +17,24 @@
     
     package org.apache.spark.ml.evaluation
     
    -import org.apache.spark.annotation.AlphaComponent
    -import org.apache.spark.ml.Evaluator
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
    -import org.apache.spark.ml.util.SchemaUtils
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
     import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
     import org.apache.spark.sql.{DataFrame, Row}
     import org.apache.spark.sql.types.DoubleType
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Evaluator for binary classification, which expects two input columns: score and label.
      */
    -@AlphaComponent
    -class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol {
    +@Experimental
    +class BinaryClassificationEvaluator(override val uid: String)
    +  extends Evaluator with HasRawPredictionCol with HasLabelCol {
    +
    +  def this() = this(Identifiable.randomUID("binEval"))
     
       /**
        * param for metric name in evaluation
    @@ -78,4 +79,6 @@ class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol w
         metrics.unpersist()
         metric
       }
    +
    +  override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
     }
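
     Usage is unchanged apart from the uid-bearing constructor. A hedged sketch, assuming a `predictions` DataFrame with "rawPrediction" and "label" columns:

     ```scala
     import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
     import org.apache.spark.sql.DataFrame

     // Compute area under the ROC curve for a scored dataset.
     def auc(predictions: DataFrame): Double = {
       new BinaryClassificationEvaluator()
         .setMetricName("areaUnderROC")
         .setRawPredictionCol("rawPrediction")
         .setLabelCol("label")
         .evaluate(predictions)
     }
     ```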
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
    similarity index 86%
    rename from mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
    rename to mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
    index 5f2f8c94e9ff7..e56c946a063e8 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
    @@ -15,21 +15,21 @@
      * limitations under the License.
      */
     
    -package org.apache.spark.ml
    +package org.apache.spark.ml.evaluation
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.ml.param.{ParamMap, Params}
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    + * :: DeveloperApi ::
      * Abstract class for evaluators that compute metrics from predictions.
      */
    -@AlphaComponent
    +@DeveloperApi
     abstract class Evaluator extends Params {
     
       /**
    -   * Evaluates the output.
    +   * Evaluates model output and returns a scalar metric (larger is better).
        *
        * @param dataset a dataset that contains labels/observations and predictions.
        * @param paramMap parameter map that specifies the input columns and output metrics
    @@ -46,7 +46,5 @@ abstract class Evaluator extends Params {
        */
       def evaluate(dataset: DataFrame): Double
     
    -  override def copy(extra: ParamMap): Evaluator = {
    -    super.copy(extra).asInstanceOf[Evaluator]
    -  }
    +  override def copy(extra: ParamMap): Evaluator
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
    new file mode 100644
    index 0000000000000..01c000b47514c
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
    @@ -0,0 +1,89 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.evaluation
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
    +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
    +import org.apache.spark.mllib.evaluation.RegressionMetrics
    +import org.apache.spark.sql.{DataFrame, Row}
    +import org.apache.spark.sql.types.DoubleType
    +
    +/**
    + * :: Experimental ::
    + * Evaluator for regression, which expects two input columns: prediction and label.
    + */
    +@Experimental
    +final class RegressionEvaluator(override val uid: String)
    +  extends Evaluator with HasPredictionCol with HasLabelCol {
    +
    +  def this() = this(Identifiable.randomUID("regEval"))
    +
    +  /**
    +   * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
    +   *
     +   * Because we will maximize the evaluation value (ref: `CrossValidator`),
     +   * when we evaluate a metric that needs to be minimized (e.g., `"rmse"`, `"mse"`, `"mae"`),
     +   * we take and output the negative of this metric.
    +   * @group param
    +   */
    +  val metricName: Param[String] = {
    +    val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
    +    new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams)
    +  }
    +
    +  /** @group getParam */
    +  def getMetricName: String = $(metricName)
    +
    +  /** @group setParam */
    +  def setMetricName(value: String): this.type = set(metricName, value)
    +
    +  /** @group setParam */
    +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
    +
    +  /** @group setParam */
    +  def setLabelCol(value: String): this.type = set(labelCol, value)
    +
    +  setDefault(metricName -> "rmse")
    +
    +  override def evaluate(dataset: DataFrame): Double = {
    +    val schema = dataset.schema
    +    SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
    +    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
    +
    +    val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
    +      .map { case Row(prediction: Double, label: Double) =>
    +        (prediction, label)
    +      }
    +    val metrics = new RegressionMetrics(predictionAndLabels)
    +    val metric = $(metricName) match {
    +      case "rmse" =>
    +        -metrics.rootMeanSquaredError
    +      case "mse" =>
    +        -metrics.meanSquaredError
    +      case "r2" =>
    +        metrics.r2
    +      case "mae" =>
    +        -metrics.meanAbsoluteError
    +    }
    +    metric
    +  }
    +
    +  override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
    +}
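
     A usage sketch for the new evaluator, assuming a `predictions` DataFrame with Double-typed "prediction" and "label" columns; note that `"rmse"`, `"mse"`, and `"mae"` are returned negated so that larger is always better:

     ```scala
     import org.apache.spark.ml.evaluation.RegressionEvaluator
     import org.apache.spark.sql.DataFrame

     // Returns -RMSE, suitable for model selection code that maximizes the metric.
     def negRmse(predictions: DataFrame): Double = {
       new RegressionEvaluator()
         .setMetricName("rmse")
         .setPredictionCol("prediction")
         .setLabelCol("label")
         .evaluate(predictions)
     }
     ```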
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
    index 6eb1db6971111..46314854d5e3a 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
    @@ -17,22 +17,25 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.Transformer
     import org.apache.spark.ml.attribute.BinaryAttribute
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
    -import org.apache.spark.ml.util.SchemaUtils
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.sql._
     import org.apache.spark.sql.functions._
     import org.apache.spark.sql.types.{DoubleType, StructType}
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Binarize a column of continuous features given a threshold.
      */
    -@AlphaComponent
    -final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
    +@Experimental
    +final class Binarizer(override val uid: String)
    +  extends Transformer with HasInputCol with HasOutputCol {
    +
    +  def this() = this(Identifiable.randomUID("binarizer"))
     
       /**
        * Param for threshold used to binarize continuous features.
    @@ -80,4 +83,6 @@ final class Binarizer extends Transformer with HasInputCol with HasOutputCol {
         val outputFields = inputFields :+ attr.toStructField()
         StructType(outputFields)
       }
    +
    +  override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
     }
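
     A short usage sketch for the uid-based `Binarizer` (column names and threshold are assumptions):

     ```scala
     import org.apache.spark.ml.feature.Binarizer
     import org.apache.spark.sql.DataFrame

     // Map a continuous column to 0.0/1.0 around a fixed threshold.
     def binarize(df: DataFrame): DataFrame = {
       val binarizer = new Binarizer()
         .setInputCol("feature")
         .setOutputCol("binarized_feature")
         .setThreshold(0.5)
       binarizer.transform(df)
     }
     ```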
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
    index b28c88aaaecbc..67e4785bc3553 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
    @@ -20,25 +20,25 @@ package org.apache.spark.ml.feature
     import java.{util => ju}
     
     import org.apache.spark.SparkException
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.Model
     import org.apache.spark.ml.attribute.NominalAttribute
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
    -import org.apache.spark.ml.util.SchemaUtils
    -import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.sql._
     import org.apache.spark.sql.functions._
     import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * `Bucketizer` maps a column of continuous features to a column of feature buckets.
      */
    -@AlphaComponent
    -final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
    +@Experimental
    +final class Bucketizer(override val uid: String)
       extends Model[Bucketizer] with HasInputCol with HasOutputCol {
     
    -  def this() = this(null)
    +  def this() = this(Identifiable.randomUID("bucketizer"))
     
       /**
        * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
    @@ -48,7 +48,7 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
        * otherwise, values outside the splits specified will be treated as errors.
        * @group param
        */
    -  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
    +  val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits",
         "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
           "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
           "bucket, which also includes y. The splits should be strictly increasing. " +
    @@ -89,6 +89,8 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
         SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
         SchemaUtils.appendColumn(schema, prepOutputField(schema))
       }
    +
    +  override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
     }
     
     private[feature] object Bucketizer {
    @@ -98,7 +100,8 @@ private[feature] object Bucketizer {
           false
         } else {
           var i = 0
    -      while (i < splits.length - 1) {
    +      val n = splits.length - 1
    +      while (i < n) {
             if (splits(i) >= splits(i + 1)) return false
             i += 1
           }
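
A minimal sketch of how the now publicly constructible Bucketizer would be driven (assumes a `sqlContext`; data and column names are illustrative only):

```scala
import org.apache.spark.ml.feature.Bucketizer

val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)

// Hypothetical single-column DataFrame of continuous values.
val df = sqlContext.createDataFrame(
  Seq(-0.9, -0.3, 0.0, 0.2, 0.7).map(Tuple1.apply)).toDF("features")

val bucketizer = new Bucketizer()    // uid generated by the new auxiliary constructor
  .setInputCol("features")
  .setOutputCol("bucketedFeatures")
  .setSplits(splits)                 // splits is now a DoubleArrayParam

bucketizer.transform(df).show()
```
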
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
    new file mode 100644
    index 0000000000000..6b77de89a0330
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
    @@ -0,0 +1,82 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.ml.feature
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.UnaryTransformer
    +import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
    +import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
    +
    +/**
    + * :: Experimental ::
    + * Converts a text document to a sparse vector of token counts.
     + * @param vocabulary An array of terms. Only the terms in the vocabulary will be counted.
    + */
    +@Experimental
    +class CountVectorizerModel (override val uid: String, val vocabulary: Array[String])
    +  extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] {
    +
    +  def this(vocabulary: Array[String]) =
    +    this(Identifiable.randomUID("cntVec"), vocabulary)
    +
    +  /**
    +   * Corpus-specific filter to ignore scarce words in a document. For each document, terms with
    +   * frequency (count) less than the given threshold are ignored.
    +   * Default: 1
    +   * @group param
    +   */
    +  val minTermFreq: IntParam = new IntParam(this, "minTermFreq",
    +    "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " +
    +      "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1))
    +
    +  /** @group setParam */
    +  def setMinTermFreq(value: Int): this.type = set(minTermFreq, value)
    +
    +  /** @group getParam */
    +  def getMinTermFreq: Int = $(minTermFreq)
    +
    +  setDefault(minTermFreq -> 1)
    +
    +  override protected def createTransformFunc: Seq[String] => Vector = {
    +    val dict = vocabulary.zipWithIndex.toMap
    +    document =>
    +      val termCounts = mutable.HashMap.empty[Int, Double]
    +      document.foreach { term =>
    +        dict.get(term) match {
    +          case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
    +          case None => // ignore terms not in the vocabulary
    +        }
    +      }
    +      Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq)
    +  }
    +
    +  override protected def validateInputType(inputType: DataType): Unit = {
    +    require(inputType.sameType(ArrayType(StringType)),
    +      s"Input type must be ArrayType(StringType) but got $inputType.")
    +  }
    +
    +  override protected def outputDataType: DataType = new VectorUDT()
    +
    +  override def copy(extra: ParamMap): CountVectorizerModel = {
    +    val copied = new CountVectorizerModel(uid, vocabulary)
    +    copyValues(copied, extra)
    +  }
    +}
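
How the new CountVectorizerModel might be applied, sketched under the assumption of a `sqlContext` and a tokenized `words` column (illustrative data):

```scala
import org.apache.spark.ml.feature.CountVectorizerModel

val df = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b", "b", "c", "a"))
)).toDF("id", "words")

// The vocabulary is supplied directly; terms outside it are ignored.
val cvm = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
  .setMinTermFreq(1)

cvm.transform(df).show()   // each row becomes a sparse count vector over the vocabulary
```
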
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
    new file mode 100644
    index 0000000000000..228347635c92b
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
    @@ -0,0 +1,72 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import edu.emory.mathcs.jtransforms.dct._
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.UnaryTransformer
    +import org.apache.spark.ml.param.BooleanParam
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
    +import org.apache.spark.sql.types.DataType
    +
    +/**
    + * :: Experimental ::
    + * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero
    + * padding is performed on the input vector.
    + * It returns a real vector of the same length representing the DCT. The return vector is scaled
    + * such that the transform matrix is unitary (aka scaled DCT-II).
    + *
    + * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]].
    + */
    +@Experimental
    +class DCT(override val uid: String)
    +  extends UnaryTransformer[Vector, Vector, DCT] {
    +
    +  def this() = this(Identifiable.randomUID("dct"))
    +
    +  /**
    +   * Indicates whether to perform the inverse DCT (true) or forward DCT (false).
    +   * Default: false
    +   * @group param
    +   */
    +  def inverse: BooleanParam = new BooleanParam(
    +    this, "inverse", "Set transformer to perform inverse DCT")
    +
    +  /** @group setParam */
    +  def setInverse(value: Boolean): this.type = set(inverse, value)
    +
    +  /** @group getParam */
    +  def getInverse: Boolean = $(inverse)
    +
    +  setDefault(inverse -> false)
    +
    +  override protected def createTransformFunc: Vector => Vector = { vec =>
    +    val result = vec.toArray
    +    val jTransformer = new DoubleDCT_1D(result.length)
    +    if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
    +    Vectors.dense(result)
    +  }
    +
    +  override protected def validateInputType(inputType: DataType): Unit = {
    +    require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
    +  }
    +
    +  override protected def outputDataType: DataType = new VectorUDT
    +}
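
A short sketch of the new DCT transformer in action (not part of the patch; assumes a `sqlContext`, and the column names are hypothetical):

```scala
import org.apache.spark.ml.feature.DCT
import org.apache.spark.mllib.linalg.Vectors

val data = Seq(
  Vectors.dense(0.0, 1.0, -2.0, 3.0),
  Vectors.dense(-1.0, 2.0, 4.0, -7.0)
).map(Tuple1.apply)
val df = sqlContext.createDataFrame(data).toDF("features")

val dct = new DCT()
  .setInputCol("features")
  .setOutputCol("featuresDCT")
  .setInverse(false)        // forward DCT; set to true for the inverse transform

dct.transform(df).select("featuresDCT").show()
```
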
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
    index f8b56293e3ccc..a359cb8f37ec3 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
    @@ -17,27 +17,31 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.UnaryTransformer
    -import org.apache.spark.ml.param.Param
    +import org.apache.spark.ml.param.{ParamMap, Param}
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.mllib.feature
     import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
     import org.apache.spark.sql.types.DataType
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
      * provided "weight" vector.  In other words, it scales each column of the dataset by a scalar
      * multiplier.
      */
    -@AlphaComponent
    -class ElementwiseProduct extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
    +@Experimental
    +class ElementwiseProduct(override val uid: String)
    +  extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
    +
    +  def this() = this(Identifiable.randomUID("elemProd"))
     
       /**
         * the vector to multiply with input vectors
         * @group param
         */
    -  val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product")
    +  val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product")
     
       /** @group setParam */
       def setScalingVec(value: Vector): this.type = set(scalingVec, value)
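
A usage sketch for ElementwiseProduct with the renamed `scalingVec` param (editorial illustration; assumes a `sqlContext`, illustrative data):

```scala
import org.apache.spark.ml.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  ("a", Vectors.dense(1.0, 2.0, 3.0)),
  ("b", Vectors.dense(4.0, 5.0, 6.0))
)).toDF("id", "vector")

val transformer = new ElementwiseProduct()
  .setScalingVec(Vectors.dense(0.0, 1.0, 2.0))   // the "weight" vector
  .setInputCol("vector")
  .setOutputCol("transformedVector")

transformer.transform(df).show()   // each vector is multiplied element-wise by the scaling vector
```
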
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
    index c305a819a8966..319d23e46cef4 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
    @@ -17,19 +17,31 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    -import org.apache.spark.ml.UnaryTransformer
    -import org.apache.spark.ml.param.{IntParam, ParamValidators}
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.Transformer
    +import org.apache.spark.ml.attribute.AttributeGroup
    +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
    +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.mllib.feature
    -import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
    -import org.apache.spark.sql.types.DataType
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.{ArrayType, StructType}
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Maps a sequence of terms to their term frequencies using the hashing trick.
      */
    -@AlphaComponent
    -class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
    +@Experimental
    +class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
    +
    +  def this() = this(Identifiable.randomUID("hashingTF"))
    +
    +  /** @group setParam */
    +  def setInputCol(value: String): this.type = set(inputCol, value)
    +
    +  /** @group setParam */
    +  def setOutputCol(value: String): this.type = set(outputCol, value)
     
       /**
        * Number of features.  Should be > 0.
    @@ -47,10 +59,21 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
       /** @group setParam */
       def setNumFeatures(value: Int): this.type = set(numFeatures, value)
     
    -  override protected def createTransformFunc: Iterable[_] => Vector = {
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    val outputSchema = transformSchema(dataset.schema)
         val hashingTF = new feature.HashingTF($(numFeatures))
    -    hashingTF.transform
    +    val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
    +    val metadata = outputSchema($(outputCol)).metadata
    +    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    val inputType = schema($(inputCol)).dataType
    +    require(inputType.isInstanceOf[ArrayType],
    +      s"The input column must be ArrayType, but got $inputType.")
    +    val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
    +    SchemaUtils.appendColumn(schema, attrGroup.toStructField())
       }
     
    -  override protected def outputDataType: DataType = new VectorUDT()
    +  override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
     }
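
Since HashingTF is now a full Transformer over an array column (rather than a UnaryTransformer), a typical pipeline step chains it after a tokenizer. A sketch, assuming a `sqlContext` and made-up sentences:

```scala
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

val sentences = sqlContext.createDataFrame(Seq(
  (0, "spark is fast"),
  (1, "hashing tf maps terms to term frequencies")
)).toDF("id", "sentence")

val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")

val hashingTF = new HashingTF()
  .setInputCol("words")            // must be an ArrayType column under the new transformSchema
  .setOutputCol("rawFeatures")
  .setNumFeatures(1000)

hashingTF.transform(tokenizer.transform(sentences)).select("rawFeatures").show()
```
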
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
    index d901a20aed002..ecde80810580c 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
    @@ -17,11 +17,11 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml._
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
    -import org.apache.spark.ml.util.SchemaUtils
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.mllib.feature
     import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
     import org.apache.spark.sql._
    @@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
       /** @group getParam */
       def getMinDocFreq: Int = $(minDocFreq)
     
    -  /** @group setParam */
    -  def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
    -
       /**
        * Validate and transform the input schema.
        */
    @@ -58,11 +55,13 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Compute the Inverse Document Frequency (IDF) given a collection of documents.
      */
    -@AlphaComponent
    -final class IDF extends Estimator[IDFModel] with IDFBase {
    +@Experimental
    +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
    +
    +  def this() = this(Identifiable.randomUID("idf"))
     
       /** @group setParam */
       def setInputCol(value: String): this.type = set(inputCol, value)
    @@ -70,25 +69,30 @@ final class IDF extends Estimator[IDFModel] with IDFBase {
       /** @group setParam */
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    +  /** @group setParam */
    +  def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
    +
       override def fit(dataset: DataFrame): IDFModel = {
         transformSchema(dataset.schema, logging = true)
         val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
         val idf = new feature.IDF($(minDocFreq)).fit(input)
    -    copyValues(new IDFModel(this, idf))
    +    copyValues(new IDFModel(uid, idf).setParent(this))
       }
     
       override def transformSchema(schema: StructType): StructType = {
         validateAndTransformSchema(schema)
       }
    +
    +  override def copy(extra: ParamMap): IDF = defaultCopy(extra)
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Model fitted by [[IDF]].
      */
    -@AlphaComponent
    +@Experimental
     class IDFModel private[ml] (
    -    override val parent: IDF,
    +    override val uid: String,
         idfModel: feature.IDFModel)
       extends Model[IDFModel] with IDFBase {
     
    @@ -107,4 +111,9 @@ class IDFModel private[ml] (
       override def transformSchema(schema: StructType): StructType = {
         validateAndTransformSchema(schema)
       }
    +
    +  override def copy(extra: ParamMap): IDFModel = {
    +    val copied = new IDFModel(uid, idfModel)
    +    copyValues(copied, extra)
    +  }
     }
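
A sketch of fitting IDF with the relocated `setMinDocFreq` setter. It assumes `featurized` is a DataFrame with a term-frequency vector column named `rawFeatures`, for example the output of the HashingTF sketch above:

```scala
import org.apache.spark.ml.feature.IDF

val idf = new IDF()
  .setInputCol("rawFeatures")
  .setOutputCol("features")
  .setMinDocFreq(1)                // setter now lives on the Estimator, not the shared trait

val idfModel = idf.fit(featurized) // IDFModel carries the estimator's uid and parent
idfModel.transform(featurized).select("features").show()
```
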
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
    new file mode 100644
    index 0000000000000..b30adf3df48d2
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
    @@ -0,0 +1,170 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
    +import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
    +import org.apache.spark.mllib.stat.Statistics
    +import org.apache.spark.sql._
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.{StructField, StructType}
    +
    +/**
    + * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]].
    + */
    +private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol {
    +
    +  /**
    +   * lower bound after transformation, shared by all features
    +   * Default: 0.0
    +   * @group param
    +   */
    +  val min: DoubleParam = new DoubleParam(this, "min",
    +    "lower bound of the output feature range")
    +
    +  /**
    +   * upper bound after transformation, shared by all features
    +   * Default: 1.0
    +   * @group param
    +   */
    +  val max: DoubleParam = new DoubleParam(this, "max",
    +    "upper bound of the output feature range")
    +
    +  /** Validates and transforms the input schema. */
    +  protected def validateAndTransformSchema(schema: StructType): StructType = {
    +    val inputType = schema($(inputCol)).dataType
    +    require(inputType.isInstanceOf[VectorUDT],
    +      s"Input column ${$(inputCol)} must be a vector column")
    +    require(!schema.fieldNames.contains($(outputCol)),
    +      s"Output column ${$(outputCol)} already exists.")
    +    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
    +    StructType(outputFields)
    +  }
    +
    +  override def validateParams(): Unit = {
     +    require($(min) < $(max), s"The specified min (${$(min)}) must be less than max (${$(max)})")
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Rescale each feature individually to a common range [min, max] linearly using column summary
    + * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for
    + * feature E is calculated as,
    + *
    + * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min
    + *
    + * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min)
    + * Note that since zero values will probably be transformed to non-zero values, output of the
    + * transformer will be DenseVector even for sparse input.
    + */
    +@Experimental
    +class MinMaxScaler(override val uid: String)
    +  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
    +
    +  def this() = this(Identifiable.randomUID("minMaxScal"))
    +
    +  setDefault(min -> 0.0, max -> 1.0)
    +
    +  /** @group setParam */
    +  def setInputCol(value: String): this.type = set(inputCol, value)
    +
    +  /** @group setParam */
    +  def setOutputCol(value: String): this.type = set(outputCol, value)
    +
    +  /** @group setParam */
    +  def setMin(value: Double): this.type = set(min, value)
    +
    +  /** @group setParam */
    +  def setMax(value: Double): this.type = set(max, value)
    +
    +  override def fit(dataset: DataFrame): MinMaxScalerModel = {
    +    transformSchema(dataset.schema, logging = true)
    +    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
    +    val summary = Statistics.colStats(input)
    +    copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +
    +  override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model fitted by [[MinMaxScaler]].
    + *
    + * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529).
    + */
    +@Experimental
    +class MinMaxScalerModel private[ml] (
    +    override val uid: String,
    +    val originalMin: Vector,
    +    val originalMax: Vector)
    +  extends Model[MinMaxScalerModel] with MinMaxScalerParams {
    +
    +  /** @group setParam */
    +  def setInputCol(value: String): this.type = set(inputCol, value)
    +
    +  /** @group setParam */
    +  def setOutputCol(value: String): this.type = set(outputCol, value)
    +
    +  /** @group setParam */
    +  def setMin(value: Double): this.type = set(min, value)
    +
    +  /** @group setParam */
    +  def setMax(value: Double): this.type = set(max, value)
    +
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
    +    val minArray = originalMin.toArray
    +
    +    val reScale = udf { (vector: Vector) =>
    +      val scale = $(max) - $(min)
    +
    +      // 0 in sparse vector will probably be rescaled to non-zero
    +      val values = vector.toArray
    +      val size = values.size
    +      var i = 0
    +      while (i < size) {
    +        val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
    +        values(i) = raw * scale + $(min)
    +        i += 1
    +      }
    +      Vectors.dense(values)
    +    }
    +
    +    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +
    +  override def copy(extra: ParamMap): MinMaxScalerModel = {
    +    val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
    +    copyValues(copied, extra)
    +  }
    +}
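
The new MinMaxScaler follows the usual fit/transform pattern. A minimal sketch, assuming a `sqlContext` and an illustrative vector column:

```scala
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -1.0)),
  (1, Vectors.dense(2.0, 1.1, 1.0)),
  (2, Vectors.dense(3.0, 10.1, 3.0))
)).toDF("id", "features")

val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setMin(0.0)
  .setMax(1.0)

val scalerModel = scaler.fit(df)   // computes per-feature original min/max
scalerModel.transform(df).show()   // output vectors are dense, as noted in the scaladoc
```
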
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
    new file mode 100644
    index 0000000000000..8de10eb51f923
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
    @@ -0,0 +1,69 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.UnaryTransformer
    +import org.apache.spark.ml.param._
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
    +
    +/**
    + * :: Experimental ::
    + * A feature transformer that converts the input array of strings into an array of n-grams. Null
    + * values in the input array are ignored.
    + * It returns an array of n-grams where each n-gram is represented by a space-separated string of
    + * words.
    + *
    + * When the input is empty, an empty array is returned.
    + * When the input array length is less than n (number of elements per n-gram), no n-grams are
    + * returned.
    + */
    +@Experimental
    +class NGram(override val uid: String)
    +  extends UnaryTransformer[Seq[String], Seq[String], NGram] {
    +
    +  def this() = this(Identifiable.randomUID("ngram"))
    +
    +  /**
     +   * Number of elements per n-gram (>= 1).
     +   * Default: 2, bigram features
    +   * @group param
    +   */
     +  val n: IntParam = new IntParam(this, "n", "number of elements per n-gram (>= 1)",
    +    ParamValidators.gtEq(1))
    +
    +  /** @group setParam */
    +  def setN(value: Int): this.type = set(n, value)
    +
    +  /** @group getParam */
    +  def getN: Int = $(n)
    +
    +  setDefault(n -> 2)
    +
    +  override protected def createTransformFunc: Seq[String] => Seq[String] = {
    +    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
    +  }
    +
    +  override protected def validateInputType(inputType: DataType): Unit = {
    +    require(inputType.sameType(ArrayType(StringType)),
    +      s"Input type must be ArrayType(StringType) but got $inputType.")
    +  }
    +
    +  override protected def outputDataType: DataType = new ArrayType(StringType, false)
    +}
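
A short sketch of the NGram transformer (assumes a `sqlContext`; the word arrays are illustrative):

```scala
import org.apache.spark.ml.feature.NGram

val wordDataFrame = sqlContext.createDataFrame(Seq(
  (0, Array("Hi", "I", "heard", "about", "Spark")),
  (1, Array("Logistic", "regression", "models", "are", "neat"))
)).toDF("id", "words")

val ngram = new NGram()
  .setN(2)                         // bigrams; rows shorter than n yield no n-grams
  .setInputCol("words")
  .setOutputCol("ngrams")

ngram.transform(wordDataFrame).select("ngrams").show()   // space-separated bigrams
```
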
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
    index 755b46a64c7f1..8282e5ffa17f7 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
    @@ -17,19 +17,22 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.UnaryTransformer
     import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.mllib.feature
     import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
     import org.apache.spark.sql.types.DataType
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Normalize a vector to have unit norm using the given p-norm.
      */
    -@AlphaComponent
    -class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
    +@Experimental
    +class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] {
    +
    +  def this() = this(Identifiable.randomUID("normalizer"))
     
       /**
        * Normalization in L^p^ space.  Must be >= 1.
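
A usage sketch for the Normalizer with its new uid-based constructor (assumes a `sqlContext`; data is illustrative):

```scala
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.5, -1.0)),
  (1, Vectors.dense(2.0, 1.0, 1.0))
)).toDF("id", "features")

val normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")
  .setP(1.0)                       // L^1 norm; p defaults to 2

normalizer.transform(df).show()
```
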
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
    index 46514ae5f0e84..3825942795645 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
    @@ -17,91 +17,154 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.SparkException
    -import org.apache.spark.annotation.AlphaComponent
    -import org.apache.spark.ml.UnaryTransformer
    -import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
    -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.Transformer
    +import org.apache.spark.ml.attribute._
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
    -import org.apache.spark.ml.util.SchemaUtils
    -import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
    +import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.{DoubleType, StructType}
     
     /**
    - * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
    - * at most a single one-value. By default, the binary vector has an element for each category, so
    - * with 5 categories, an input value of 2.0 would map to an output vector of
    - * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
    - * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
    - * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
    - * linearly dependent because they sum up to one.
    + * :: Experimental ::
    + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
    + * at most a single one-value per row that indicates the input category index.
    + * For example with 5 categories, an input value of 2.0 would map to an output vector of
    + * `[0.0, 0.0, 1.0, 0.0]`.
     + * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]])
     + * because it makes the vector entries sum up to one, and hence linearly dependent.
    + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
    + * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories.
    + * The output vectors are sparse.
    + *
    + * @see [[StringIndexer]] for converting categorical values into category indices
      */
    -@AlphaComponent
    -class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
    +@Experimental
    +class OneHotEncoder(override val uid: String) extends Transformer
       with HasInputCol with HasOutputCol {
     
    +  def this() = this(Identifiable.randomUID("oneHot"))
    +
       /**
    -   * Whether to include a component in the encoded vectors for the first category, defaults to true.
    +   * Whether to drop the last category in the encoded vector (default: true)
        * @group param
        */
    -  final val includeFirst: BooleanParam =
    -    new BooleanParam(this, "includeFirst", "include first category")
    -  setDefault(includeFirst -> true)
    -
    -  private var categories: Array[String] = _
    +  final val dropLast: BooleanParam =
    +    new BooleanParam(this, "dropLast", "whether to drop the last category")
    +  setDefault(dropLast -> true)
     
       /** @group setParam */
    -  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
    +  def setDropLast(value: Boolean): this.type = set(dropLast, value)
     
       /** @group setParam */
    -  override def setInputCol(value: String): this.type = set(inputCol, value)
    +  def setInputCol(value: String): this.type = set(inputCol, value)
     
       /** @group setParam */
    -  override def setOutputCol(value: String): this.type = set(outputCol, value)
    +  def setOutputCol(value: String): this.type = set(outputCol, value)
     
       override def transformSchema(schema: StructType): StructType = {
    -    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
    -    val inputFields = schema.fields
    +    val is = "_is_"
    +    val inputColName = $(inputCol)
         val outputColName = $(outputCol)
    -    require(inputFields.forall(_.name != $(outputCol)),
    -      s"Output column ${$(outputCol)} already exists.")
     
    -    val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
    -    categories = inputColAttr match {
    +    SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
    +    val inputFields = schema.fields
    +    require(!inputFields.exists(_.name == outputColName),
    +      s"Output column $outputColName already exists.")
    +
    +    val inputAttr = Attribute.fromStructField(schema(inputColName))
    +    val outputAttrNames: Option[Array[String]] = inputAttr match {
           case nominal: NominalAttribute =>
    -        nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
    -      case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
    +        if (nominal.values.isDefined) {
    +          nominal.values.map(_.map(v => inputColName + is + v))
    +        } else if (nominal.numValues.isDefined) {
    +          nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
    +        } else {
    +          None
    +        }
    +      case binary: BinaryAttribute =>
    +        if (binary.values.isDefined) {
    +          binary.values.map(_.map(v => inputColName + is + v))
    +        } else {
    +          Some(Array.tabulate(2)(i => inputColName + is + i))
    +        }
    +      case _: NumericAttribute =>
    +        throw new RuntimeException(
    +          s"The input column $inputColName cannot be numeric.")
           case _ =>
    -        throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
    +        None // optimistic about unknown attributes
    +    }
    +
    +    val filteredOutputAttrNames = outputAttrNames.map { names =>
    +      if ($(dropLast)) {
    +        require(names.length > 1,
    +          s"The input column $inputColName should have at least two distinct values.")
    +        names.dropRight(1)
    +      } else {
    +        names
    +      }
         }
     
    -    val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
    -    val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
    -    val outputFields = inputFields :+ attr.toStructField()
    +    val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
    +      val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
    +        BinaryAttribute.defaultAttr.withName(name)
    +      }
    +      new AttributeGroup($(outputCol), attrs)
    +    } else {
    +      new AttributeGroup($(outputCol))
    +    }
    +
    +    val outputFields = inputFields :+ outputAttrGroup.toStructField()
         StructType(outputFields)
       }
     
    -  protected override def createTransformFunc(): (Double) => Vector = {
    -    val first = $(includeFirst)
    -    val vecLen = if (first) categories.length else categories.length - 1
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    // schema transformation
    +    val is = "_is_"
    +    val inputColName = $(inputCol)
    +    val outputColName = $(outputCol)
    +    val shouldDropLast = $(dropLast)
    +    var outputAttrGroup = AttributeGroup.fromStructField(
    +      transformSchema(dataset.schema)(outputColName))
    +    if (outputAttrGroup.size < 0) {
    +      // If the number of attributes is unknown, we check the values from the input column.
    +      val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0))
    +        .aggregate(0.0)(
    +          (m, x) => {
     +            assert(x >= 0.0 && x == x.toInt,
    +              s"Values from column $inputColName must be indices, but got $x.")
    +            math.max(m, x)
    +          },
    +          (m0, m1) => {
    +            math.max(m0, m1)
    +          }
    +        ).toInt + 1
    +      val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
    +      val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
    +      val outputAttrs: Array[Attribute] =
    +        filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
    +      outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
    +    }
    +    val metadata = outputAttrGroup.toMetadata()
    +
    +    // data transformation
    +    val size = outputAttrGroup.size
         val oneValue = Array(1.0)
         val emptyValues = Array[Double]()
         val emptyIndices = Array[Int]()
    -    label: Double => {
    -      val values = if (first || label != 0.0) oneValue else emptyValues
    -      val indices = if (first) {
    -        Array(label.toInt)
    -      } else if (label != 0.0) {
    -        Array(label.toInt - 1)
    +    val encode = udf { label: Double =>
    +      if (label < size) {
    +        Vectors.sparse(size, Array(label.toInt), oneValue)
           } else {
    -        emptyIndices
    +        Vectors.sparse(size, emptyIndices, emptyValues)
           }
    -      Vectors.sparse(vecLen, indices, values)
         }
    +
    +    dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
       }
     
    -  /**
    -   * Returns the data type of the output column.
    -   */
    -  protected def outputDataType: DataType = new VectorUDT
    +  override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
     }
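
With the switch from `includeFirst` to `dropLast`, a typical usage pairs StringIndexer with OneHotEncoder. A sketch under the usual spark-shell assumptions (`sqlContext` in scope, made-up categories):

```scala
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

val indexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
val indexed = indexerModel.transform(df)

val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")     // dropLast defaults to true: 3 categories -> vectors of size 2

encoder.transform(indexed).select("categoryVec").show()
```
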
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
    new file mode 100644
    index 0000000000000..2d3bb680cf309
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
    @@ -0,0 +1,130 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml._
    +import org.apache.spark.ml.param._
    +import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.feature
    +import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
    +import org.apache.spark.sql._
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.{StructField, StructType}
    +
    +/**
    + * Params for [[PCA]] and [[PCAModel]].
    + */
    +private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol {
    +
    +  /**
    +   * The number of principal components.
    +   * @group param
    +   */
    +  final val k: IntParam = new IntParam(this, "k", "the number of principal components")
    +
    +  /** @group getParam */
    +  def getK: Int = $(k)
    +
    +}
    +
    +/**
    + * :: Experimental ::
    + * PCA trains a model to project vectors to a low-dimensional space using PCA.
    + */
    +@Experimental
    +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
    +
    +  def this() = this(Identifiable.randomUID("pca"))
    +
    +  /** @group setParam */
    +  def setInputCol(value: String): this.type = set(inputCol, value)
    +
    +  /** @group setParam */
    +  def setOutputCol(value: String): this.type = set(outputCol, value)
    +
    +  /** @group setParam */
    +  def setK(value: Int): this.type = set(k, value)
    +
    +  /**
    +   * Computes a [[PCAModel]] that contains the principal components of the input vectors.
    +   */
    +  override def fit(dataset: DataFrame): PCAModel = {
    +    transformSchema(dataset.schema, logging = true)
    +    val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
    +    val pca = new feature.PCA(k = $(k))
    +    val pcaModel = pca.fit(input)
    +    copyValues(new PCAModel(uid, pcaModel).setParent(this))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    val inputType = schema($(inputCol)).dataType
    +    require(inputType.isInstanceOf[VectorUDT],
    +      s"Input column ${$(inputCol)} must be a vector column")
    +    require(!schema.fieldNames.contains($(outputCol)),
    +      s"Output column ${$(outputCol)} already exists.")
    +    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
    +    StructType(outputFields)
    +  }
    +
    +  override def copy(extra: ParamMap): PCA = defaultCopy(extra)
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model fitted by [[PCA]].
    + */
    +@Experimental
    +class PCAModel private[ml] (
    +    override val uid: String,
    +    pcaModel: feature.PCAModel)
    +  extends Model[PCAModel] with PCAParams {
    +
    +  /** @group setParam */
    +  def setInputCol(value: String): this.type = set(inputCol, value)
    +
    +  /** @group setParam */
    +  def setOutputCol(value: String): this.type = set(outputCol, value)
    +
    +  /**
    +   * Transform a vector by computed Principal Components.
    +   * NOTE: Vectors to be transformed must be the same length
    +   * as the source vectors given to [[PCA.fit()]].
    +   */
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    transformSchema(dataset.schema, logging = true)
    +    val pcaOp = udf { pcaModel.transform _ }
    +    dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    val inputType = schema($(inputCol)).dataType
    +    require(inputType.isInstanceOf[VectorUDT],
    +      s"Input column ${$(inputCol)} must be a vector column")
    +    require(!schema.fieldNames.contains($(outputCol)),
    +      s"Output column ${$(outputCol)} already exists.")
    +    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
    +    StructType(outputFields)
    +  }
    +
    +  override def copy(extra: ParamMap): PCAModel = {
    +    val copied = new PCAModel(uid, pcaModel)
    +    copyValues(copied, extra)
    +  }
    +}
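
A sketch of the new PCA estimator and its fitted model (assumes a `sqlContext`; the vectors are illustrative):

```scala
import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

val data = Seq(
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0),
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))
).map(Tuple1.apply)
val df = sqlContext.createDataFrame(data).toDF("features")

val pcaModel = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)                         // keep the top 3 principal components
  .fit(df)

pcaModel.transform(df).select("pcaFeatures").show()
```
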
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
    index 9e6177ca27e4a..d85e468562d4a 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
    @@ -19,22 +19,26 @@ package org.apache.spark.ml.feature
     
     import scala.collection.mutable
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.UnaryTransformer
    -import org.apache.spark.ml.param.{IntParam, ParamValidators}
    +import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.mllib.linalg._
     import org.apache.spark.sql.types.DataType
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion,
      * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an
      * expansion of a product of sums expresses it as a sum of products by using the fact that
      * multiplication distributes over addition". Take a 2-variable feature vector as an example:
      * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`.
      */
    -@AlphaComponent
    -class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
    +@Experimental
    +class PolynomialExpansion(override val uid: String)
    +  extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
    +
    +  def this() = this(Identifiable.randomUID("poly"))
     
       /**
        * The polynomial degree to expand, which should be >= 1.  A value of 1 means no expansion.
    @@ -57,6 +61,8 @@ class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExp
       }
     
       override protected def outputDataType: DataType = new VectorUDT()
    +
    +  override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
     }
     
     /**
    @@ -71,7 +77,7 @@ class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExp
      * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the
      * current index and increment it properly for sparse input.
      */
    -object PolynomialExpansion {
    +private[feature] object PolynomialExpansion {
     
       private def choose(n: Int, k: Int): Int = {
         Range(n, n - k, -1).product / Range(k, 1, -1).product
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
    index 7cad59ff3fa37..72b545e5db3e4 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
    @@ -17,10 +17,11 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml._
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.mllib.feature
     import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
     import org.apache.spark.sql._
    @@ -34,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
     
       /**
        * Centers the data with mean before scaling.
    -   * It will build a dense output, so this does not work on sparse input 
    +   * It will build a dense output, so this does not work on sparse input
        * and will raise an exception.
        * Default: false
        * @group param
        */
       val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
    -  
    +
       /**
        * Scales the data to unit standard deviation.
        * Default: true
    @@ -50,12 +51,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Standardizes features by removing the mean and scaling to unit variance using column summary
      * statistics on the samples in the training set.
      */
    -@AlphaComponent
    -class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
    +@Experimental
    +class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
    +  with StandardScalerParams {
    +
    +  def this() = this(Identifiable.randomUID("stdScal"))
     
       setDefault(withMean -> false, withStd -> true)
     
    @@ -64,19 +68,19 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
     
       /** @group setParam */
       def setOutputCol(value: String): this.type = set(outputCol, value)
    -  
    +
       /** @group setParam */
       def setWithMean(value: Boolean): this.type = set(withMean, value)
    -  
    +
       /** @group setParam */
       def setWithStd(value: Boolean): this.type = set(withStd, value)
    -  
    +
       override def fit(dataset: DataFrame): StandardScalerModel = {
         transformSchema(dataset.schema, logging = true)
         val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
         val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
         val scalerModel = scaler.fit(input)
    -    copyValues(new StandardScalerModel(this, scalerModel))
    +    copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
       }
     
       override def transformSchema(schema: StructType): StructType = {
    @@ -88,18 +92,26 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
         val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
         StructType(outputFields)
       }
    +
    +  override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Model fitted by [[StandardScaler]].
      */
    -@AlphaComponent
    +@Experimental
     class StandardScalerModel private[ml] (
    -    override val parent: StandardScaler,
    +    override val uid: String,
         scaler: feature.StandardScalerModel)
       extends Model[StandardScalerModel] with StandardScalerParams {
     
    +  /** Standard deviation of the StandardScalerModel */
    +  val std: Vector = scaler.std
    +
    +  /** Mean of the StandardScalerModel */
    +  val mean: Vector = scaler.mean
    +
       /** @group setParam */
       def setInputCol(value: String): this.type = set(inputCol, value)
     
    @@ -121,4 +133,9 @@ class StandardScalerModel private[ml] (
         val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
         StructType(outputFields)
       }
    +
    +  override def copy(extra: ParamMap): StandardScalerModel = {
    +    val copied = new StandardScalerModel(uid, scaler)
    +    copyValues(copied, extra)
    +  }
     }
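
The StandardScalerModel now exposes `std` and `mean` directly. A usage sketch (assumes a `sqlContext`; data and column names are illustrative):

```scala
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -1.0)),
  (1, Vectors.dense(2.0, 1.1, 1.0))
)).toDF("id", "features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)
  .setWithMean(false)              // withMean densifies the output, so keep it off for sparse data

val scalerModel = scaler.fit(df)
println(scalerModel.std)           // newly exposed standard deviation vector
println(scalerModel.mean)          // newly exposed mean vector
scalerModel.transform(df).show()
```
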
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
    index 3d78537ad84cb..bf7be363b8224 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
    @@ -18,11 +18,12 @@
     package org.apache.spark.ml.feature
     
     import org.apache.spark.SparkException
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{Estimator, Model}
     import org.apache.spark.ml.attribute.NominalAttribute
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.sql.DataFrame
     import org.apache.spark.sql.functions._
     import org.apache.spark.sql.types.{NumericType, StringType, StructType}
    @@ -51,14 +52,17 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * A label indexer that maps a string column of labels to an ML column of label indices.
      * If the input column is numeric, we cast it to string and index the string values.
      * The indices are in [0, numLabels), ordered by label frequencies.
      * So the most frequent label gets index 0.
      */
    -@AlphaComponent
    -class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase {
    +@Experimental
    +class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
    +  with StringIndexerBase {
    +
    +  def this() = this(Identifiable.randomUID("strIdx"))
     
       /** @group setParam */
       def setInputCol(value: String): this.type = set(inputCol, value)
    @@ -73,21 +77,26 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
           .map(_.getString(0))
           .countByValue()
         val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
    -    copyValues(new StringIndexerModel(this, labels))
    +    copyValues(new StringIndexerModel(uid, labels).setParent(this))
       }
     
       override def transformSchema(schema: StructType): StructType = {
         validateAndTransformSchema(schema)
       }
    +
    +  override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Model fitted by [[StringIndexer]].
     + * NOTE: During transformation, if the input column does not exist,
     + * [[StringIndexerModel.transform]] returns the input dataset unmodified.
     + * This is a temporary fix for the case when target labels do not exist during prediction.
      */
    -@AlphaComponent
    +@Experimental
     class StringIndexerModel private[ml] (
    -    override val parent: StringIndexer,
    +    override val uid: String,
         labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
     
       private val labelToIndex: OpenHashMap[String, Double] = {
    @@ -108,6 +117,12 @@ class StringIndexerModel private[ml] (
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
       override def transform(dataset: DataFrame): DataFrame = {
    +    if (!dataset.schema.fieldNames.contains($(inputCol))) {
    +      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
    +        "Skip StringIndexerModel.")
    +      return dataset
    +    }
    +
         val indexer = udf { label: String =>
           if (labelToIndex.contains(label)) {
             labelToIndex(label)
    @@ -124,6 +139,16 @@ class StringIndexerModel private[ml] (
       }
     
       override def transformSchema(schema: StructType): StructType = {
    -    validateAndTransformSchema(schema)
    +    if (schema.fieldNames.contains($(inputCol))) {
    +      validateAndTransformSchema(schema)
    +    } else {
    +      // If the input column does not exist during transformation, we skip StringIndexerModel.
    +      schema
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): StringIndexerModel = {
    +    val copied = new StringIndexerModel(uid, labels)
    +    copyValues(copied, extra)
       }
     }
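
A sketch of StringIndexer, whose model now carries a uid and skips transformation when the input column is absent (assumes a `sqlContext`; the categories are made up):

```scala
import org.apache.spark.ml.feature.StringIndexer

val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

val indexed = indexer.fit(df).transform(df)   // "a" -> 0.0 (most frequent), then "c", then "b"
indexed.show()
```
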
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    index 649c217b16590..5f9f57a2ebcfa 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    @@ -17,17 +17,22 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.UnaryTransformer
     import org.apache.spark.ml.param._
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * A tokenizer that converts the input string to lowercase and then splits it by white spaces.
    + *
    + * @see [[RegexTokenizer]]
      */
    -@AlphaComponent
    -class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
    +@Experimental
    +class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] {
    +
    +  def this() = this(Identifiable.randomUID("tok"))
     
       override protected def createTransformFunc: String => Seq[String] = {
         _.toLowerCase.split("\\s")
    @@ -38,24 +43,29 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
       }
     
       override protected def outputDataType: DataType = new ArrayType(StringType, false)
    +
    +  override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
     }
     
     /**
    - * :: AlphaComponent ::
    - * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
    - * or using it to split the text (set matching to false). Optional parameters also allow filtering
    - * tokens using a minimal length.
    + * :: Experimental ::
    + * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split
     + * the text (default) or repeatedly matching the regex (if `gaps` is false).
    + * Optional parameters also allow filtering tokens using a minimal length.
      * It returns an array of strings that can be empty.
      */
    -@AlphaComponent
    -class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
    +@Experimental
    +class RegexTokenizer(override val uid: String)
    +  extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
    +
    +  def this() = this(Identifiable.randomUID("regexTok"))
     
       /**
        * Minimum token length, >= 0.
        * Default: 1, to avoid returning empty strings
        * @group param
        */
    -  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
    +  val minTokenLength: IntParam = new IntParam(this, "minTokenLength", "minimum token length (>= 0)",
         ParamValidators.gtEq(0))
     
       /** @group setParam */
    @@ -65,8 +75,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
       def getMinTokenLength: Int = $(minTokenLength)
     
       /**
    -   * Indicates whether regex splits on gaps (true) or matching tokens (false).
    -   * Default: false
    +   * Indicates whether regex splits on gaps (true) or matches tokens (false).
    +   * Default: true
        * @group param
        */
       val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
    @@ -78,8 +88,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
       def getGaps: Boolean = $(gaps)
     
       /**
    -   * Regex pattern used by tokenizer.
    -   * Default: `"\\p{L}+|[^\\p{L}\\s]+"`
    +   * Regex pattern used to match delimiters if [[gaps]] is true or tokens if [[gaps]] is false.
    +   * Default: `"\\s+"`
        * @group param
        */
       val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")
    @@ -90,7 +100,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
       /** @group getParam */
       def getPattern: String = $(pattern)
     
    -  setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+")
    +  setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+")
     
       override protected def createTransformFunc: String => Seq[String] = { str =>
         val re = $(pattern).r
    @@ -104,4 +114,6 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
       }
     
       override protected def outputDataType: DataType = new ArrayType(StringType, false)
    +
    +  override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
     }
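
The default behavior of RegexTokenizer changes in this patch: with gaps = true and pattern = "\\s+", the pattern is now treated as a delimiter and the text is split on whitespace by default. A minimal sketch of both modes follows; it is not part of the patch, and the data, column names, and object name are made up.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer}

    object TokenizerExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("TokenizerExample").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val df = sc.parallelize(Seq(
          (0, "Hi I heard about Spark"),
          (1, "logistic,regression,models"))).toDF("id", "sentence")

        // Lowercases and splits on whitespace.
        new Tokenizer().setInputCol("sentence").setOutputCol("words").transform(df).show()

        // New defaults: the pattern "\\s+" is treated as the delimiter (gaps = true).
        new RegexTokenizer().setInputCol("sentence").setOutputCol("words").transform(df).show()

        // gaps = false flips the interpretation: the pattern matches the tokens themselves.
        new RegexTokenizer()
          .setInputCol("sentence").setOutputCol("words")
          .setGaps(false).setPattern("\\w+")
          .transform(df).show()

        sc.stop()
      }
    }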
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    index 796758a70ef18..9f83c2ee16178 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    @@ -20,20 +20,26 @@ package org.apache.spark.ml.feature
     import scala.collection.mutable.ArrayBuilder
     
     import org.apache.spark.SparkException
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.Transformer
    +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
    +import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
     import org.apache.spark.sql.{DataFrame, Row}
     import org.apache.spark.sql.functions._
     import org.apache.spark.sql.types._
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * A feature transformer that merges multiple columns into a vector column.
      */
    -@AlphaComponent
    -class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
    +@Experimental
    +class VectorAssembler(override val uid: String)
    +  extends Transformer with HasInputCols with HasOutputCol {
    +
    +  def this() = this(Identifiable.randomUID("vecAssembler"))
     
       /** @group setParam */
       def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    @@ -42,19 +48,59 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
       override def transform(dataset: DataFrame): DataFrame = {
    +    // Schema transformation.
    +    val schema = dataset.schema
    +    lazy val first = dataset.first()
    +    val attrs = $(inputCols).flatMap { c =>
    +      val field = schema(c)
    +      val index = schema.fieldIndex(c)
    +      field.dataType match {
    +        case DoubleType =>
    +          val attr = Attribute.fromStructField(field)
    +          // If the input column doesn't have ML attribute, assume numeric.
    +          if (attr == UnresolvedAttribute) {
    +            Some(NumericAttribute.defaultAttr.withName(c))
    +          } else {
    +            Some(attr.withName(c))
    +          }
    +        case _: NumericType | BooleanType =>
    +          // If the input column type is a compatible scalar type, assume numeric.
    +          Some(NumericAttribute.defaultAttr.withName(c))
    +        case _: VectorUDT =>
    +          val group = AttributeGroup.fromStructField(field)
    +          if (group.attributes.isDefined) {
    +            // If attributes are defined, copy them with updated names.
    +            group.attributes.get.map { attr =>
    +              if (attr.name.isDefined) {
    +                // TODO: Define a rigorous naming scheme.
    +                attr.withName(c + "_" + attr.name.get)
    +              } else {
    +                attr
    +              }
    +            }
    +          } else {
    +            // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
    +            // from metadata, check the first row.
    +            val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
    +            Array.fill(numAttrs)(NumericAttribute.defaultAttr)
    +          }
    +      }
    +    }
    +    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
    +
    +    // Data transformation.
         val assembleFunc = udf { r: Row =>
           VectorAssembler.assemble(r.toSeq: _*)
         }
    -    val schema = dataset.schema
    -    val inputColNames = $(inputCols)
    -    val args = inputColNames.map { c =>
    +    val args = $(inputCols).map { c =>
           schema(c).dataType match {
             case DoubleType => dataset(c)
             case _: VectorUDT => dataset(c)
             case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
           }
         }
    -    dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol)))
    +
    +    dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata))
       }
     
       override def transformSchema(schema: StructType): StructType = {
    @@ -72,10 +118,11 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
         }
         StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
       }
    +
    +  override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
     }
     
    -@AlphaComponent
    -object VectorAssembler {
    +private object VectorAssembler {
     
       private[feature] def assemble(vv: Any*): Vector = {
         val indices = ArrayBuilder.make[Int]
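
The transform() hunk above now builds an AttributeGroup for the output column from the input columns' ML attributes. Here is a minimal sketch (not part of the patch) that assembles a double, a boolean, and a vector column and inspects the resulting metadata; the data and column names are invented for illustration.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.mllib.linalg.Vectors

    object VectorAssemblerExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("VectorAssemblerExample").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // One double column, one boolean column, and one vector column.
        val df = sc.parallelize(Seq(
          (1.0, true, Vectors.dense(0.5, -1.0)),
          (0.0, false, Vectors.dense(2.0, 3.0))
        )).toDF("hour", "clicked", "userFeatures")

        val assembler = new VectorAssembler()
          .setInputCols(Array("hour", "clicked", "userFeatures"))
          .setOutputCol("features")

        // Boolean/numeric columns are cast to double, vector columns are flattened, and the
        // output column carries an AttributeGroup built from the inputs (as in the hunk above).
        val assembled = assembler.transform(df)
        assembled.select("features").show()
        println(assembled.schema("features").metadata)

        sc.stop()
      }
    }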
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
    index 2e6313ac14485..c73bdccdef5fa 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
    @@ -17,15 +17,20 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import java.lang.{Double => JDouble, Integer => JInt}
    +import java.util.{Map => JMap}
    +
    +import scala.collection.JavaConverters._
    +
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{Estimator, Model}
     import org.apache.spark.ml.attribute._
    -import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
    +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
     import org.apache.spark.ml.param.shared._
    -import org.apache.spark.ml.util.SchemaUtils
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
     import org.apache.spark.sql.{DataFrame, Row}
    -import org.apache.spark.sql.functions.callUDF
    +import org.apache.spark.sql.functions.udf
     import org.apache.spark.sql.types.{StructField, StructType}
     import org.apache.spark.util.collection.OpenHashSet
     
    @@ -51,8 +56,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Class for indexing categorical feature columns in a dataset of [[Vector]].
      *
      * This has 2 usage modes:
    @@ -86,8 +90,11 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
      *  - Add warning if a categorical feature has only 1 category.
      *  - Add option for allowing unknown categories.
      */
    -@AlphaComponent
    -class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams {
    +@Experimental
    +class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
    +  with VectorIndexerParams {
    +
    +  def this() = this(Identifiable.randomUID("vecIdx"))
     
       /** @group setParam */
       def setMaxCategories(value: Int): this.type = set(maxCategories, value)
    @@ -110,7 +117,9 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara
           iter.foreach(localCatStats.addVector)
           Iterator(localCatStats)
         }.reduce((stats1, stats2) => stats1.merge(stats2))
    -    copyValues(new VectorIndexerModel(this, numFeatures, categoryStats.getCategoryMaps))
    +    val model = new VectorIndexerModel(uid, numFeatures, categoryStats.getCategoryMaps)
    +      .setParent(this)
    +    copyValues(model)
       }
     
       override def transformSchema(schema: StructType): StructType = {
    @@ -122,6 +131,8 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara
         SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
         SchemaUtils.appendColumn(schema, $(outputCol), dataType)
       }
    +
    +  override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
     }
     
     private object VectorIndexer {
    @@ -189,7 +200,8 @@ private object VectorIndexer {
     
         private def addDenseVector(dv: DenseVector): Unit = {
           var i = 0
    -      while (i < dv.size) {
    +      val size = dv.size
    +      while (i < size) {
             if (featureValueSets(i).size <= maxCategories) {
               featureValueSets(i).add(dv(i))
             }
    @@ -201,7 +213,8 @@ private object VectorIndexer {
           // TODO: This might be able to handle 0's more efficiently.
           var vecIndex = 0 // index into vector
           var k = 0 // index into non-zero elements
    -      while (vecIndex < sv.size) {
    +      val size = sv.size
    +      while (vecIndex < size) {
             val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) {
               k += 1
               sv.values(k - 1)
    @@ -218,8 +231,7 @@ private object VectorIndexer {
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Transform categorical features to use 0-based indices instead of their original values.
      *  - Categorical features are mapped to indices.
      *  - Continuous features (columns) are left unchanged.
    @@ -234,13 +246,18 @@ private object VectorIndexer {
      *                      Values are maps from original features values to 0-based category indices.
      *                      If a feature is not in this map, it is treated as continuous.
      */
    -@AlphaComponent
    +@Experimental
     class VectorIndexerModel private[ml] (
    -    override val parent: VectorIndexer,
    +    override val uid: String,
         val numFeatures: Int,
         val categoryMaps: Map[Int, Map[Double, Int]])
       extends Model[VectorIndexerModel] with VectorIndexerParams {
     
    +  /** Java-friendly version of [[categoryMaps]] */
    +  def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = {
    +    categoryMaps.mapValues(_.asJava).asJava.asInstanceOf[JMap[JInt, JMap[JDouble, JInt]]]
    +  }
    +
       /**
        * Pre-computed feature attributes, with some missing info.
        * In transform(), set attribute name and other info, if available.
    @@ -322,7 +339,8 @@ class VectorIndexerModel private[ml] (
       override def transform(dataset: DataFrame): DataFrame = {
         transformSchema(dataset.schema, logging = true)
         val newField = prepOutputField(dataset.schema)
    -    val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol)))
    +    val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
    +    val newCol = transformUDF(dataset($(inputCol)))
         dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
       }
     
    @@ -384,4 +402,9 @@ class VectorIndexerModel private[ml] (
         val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
         newAttributeGroup.toStructField()
       }
    +
    +  override def copy(extra: ParamMap): VectorIndexerModel = {
    +    val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
    +    copyValues(copied, extra)
    +  }
     }
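
A minimal sketch (not part of the patch) of fitting a VectorIndexer and inspecting the learned category maps; the data and the maxCategories value are arbitrary, and the new javaCategoryMaps method added above exposes the same information as java.util maps for Java callers.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.ml.feature.VectorIndexer
    import org.apache.spark.mllib.linalg.Vectors

    object VectorIndexerExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("VectorIndexerExample").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // Feature 0 has more than maxCategories distinct values, so it is treated as continuous;
        // feature 1 has only two distinct values, so it is indexed as categorical.
        val df = sc.parallelize(Seq(
          Vectors.dense(1.1, 0.0),
          Vectors.dense(2.3, 1.0),
          Vectors.dense(3.7, 0.0),
          Vectors.dense(4.2, 1.0)
        ).map(Tuple1.apply)).toDF("features")

        val model = new VectorIndexer()
          .setInputCol("features")
          .setOutputCol("indexed")
          .setMaxCategories(2)
          .fit(df)

        println(model.categoryMaps)   // e.g. Map(1 -> Map(0.0 -> 0, 1.0 -> 1))
        model.transform(df).show()

        sc.stop()
      }
    }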
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
    index 34ff92970129f..6ea6590956300 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
    @@ -17,11 +17,11 @@
     
     package org.apache.spark.ml.feature
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{Estimator, Model}
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
    -import org.apache.spark.ml.util.SchemaUtils
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.mllib.feature
     import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
     import org.apache.spark.mllib.linalg.BLAS._
    @@ -37,6 +37,7 @@ private[feature] trait Word2VecBase extends Params
     
       /**
        * The dimension of the code that you want to transform from words.
    +   * @group param
        */
       final val vectorSize = new IntParam(
         this, "vectorSize", "the dimension of codes after transforming from words")
    @@ -47,6 +48,7 @@ private[feature] trait Word2VecBase extends Params
     
       /**
        * Number of partitions for sentences of words.
    +   * @group param
        */
       final val numPartitions = new IntParam(
         this, "numPartitions", "number of partitions for sentences of words")
    @@ -58,6 +60,7 @@ private[feature] trait Word2VecBase extends Params
       /**
        * The minimum number of times a token must appear to be included in the word2vec model's
        * vocabulary.
    +   * @group param
        */
       final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " +
         "appear to be included in the word2vec model's vocabulary")
    @@ -68,7 +71,6 @@ private[feature] trait Word2VecBase extends Params
     
       setDefault(stepSize -> 0.025)
       setDefault(maxIter -> 1)
    -  setDefault(seed -> 42L)
     
       /**
        * Validate and transform the input schema.
    @@ -80,12 +82,14 @@ private[feature] trait Word2VecBase extends Params
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further
      * natural language processing or machine learning process.
      */
    -@AlphaComponent
    -final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
    +@Experimental
    +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase {
    +
    +  def this() = this(Identifiable.randomUID("w2v"))
     
       /** @group setParam */
       def setInputCol(value: String): this.type = set(inputCol, value)
    @@ -122,21 +126,23 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
           .setSeed($(seed))
           .setVectorSize($(vectorSize))
           .fit(input)
    -    copyValues(new Word2VecModel(this, wordVectors))
    +    copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
       }
     
       override def transformSchema(schema: StructType): StructType = {
         validateAndTransformSchema(schema)
       }
    +
    +  override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Model fitted by [[Word2Vec]].
      */
    -@AlphaComponent
    +@Experimental
     class Word2VecModel private[ml] (
    -    override val parent: Word2Vec,
    +    override val uid: String,
         wordVectors: feature.Word2VecModel)
       extends Model[Word2VecModel] with Word2VecBase {
     
    @@ -176,4 +182,9 @@ class Word2VecModel private[ml] (
       override def transformSchema(schema: StructType): StructType = {
         validateAndTransformSchema(schema)
       }
    +
    +  override def copy(extra: ParamMap): Word2VecModel = {
    +    val copied = new Word2VecModel(uid, wordVectors)
    +    copyValues(copied, extra)
    +  }
     }
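
A minimal sketch (not part of the patch) of fitting the ml Word2Vec estimator introduced above; the text, parameter values, and object name are placeholders, and the setter names follow the params shown in the diff.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.ml.feature.Word2Vec

    object Word2VecExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("Word2VecExample").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val docs = sc.parallelize(Seq(
          "Hi I heard about Spark".split(" "),
          "I wish Java could use case classes".split(" "),
          "Logistic regression models are neat".split(" ")
        ).map(Tuple1.apply)).toDF("text")

        val model = new Word2Vec()
          .setInputCol("text")
          .setOutputCol("result")
          .setVectorSize(3)
          .setMinCount(0)
          .fit(docs)

        // Each document is mapped to a vector (the average of its word vectors).
        model.transform(docs).select("result").show()

        sc.stop()
      }
    }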
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
    index 00d9c802e930d..87f4223964ada 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java
    +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
    @@ -16,10 +16,10 @@
      */
     
     /**
    - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
    + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly
      * assemble and configure practical machine learning pipelines.
      */
    -@AlphaComponent
    +@Experimental
     package org.apache.spark.ml;
     
    -import org.apache.spark.annotation.AlphaComponent;
    +import org.apache.spark.annotation.Experimental;
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
    index ac75e9de1a8f2..c589d06d9f7e4 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
    @@ -18,7 +18,7 @@
     package org.apache.spark
     
     /**
    - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
    + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly
      * assemble and configure practical machine learning pipelines.
      *
      * @groupname param Parameters
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    index 7ebbf106ee753..d034d7ec6b60e 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    @@ -24,11 +24,11 @@ import scala.annotation.varargs
     import scala.collection.mutable
     import scala.collection.JavaConverters._
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.{DeveloperApi, Experimental}
     import org.apache.spark.ml.util.Identifiable
     
     /**
    - * :: AlphaComponent ::
    + * :: DeveloperApi ::
      * A param with self-contained documentation and optionally default value. Primitive-typed param
      * should use the specialized versions, which are more friendly to Java users.
      *
    @@ -39,13 +39,18 @@ import org.apache.spark.ml.util.Identifiable
      *                See [[ParamValidators]] for factory methods for common validation functions.
      * @tparam T param value type
      */
    -@AlphaComponent
    -class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean)
    +@DeveloperApi
    +class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
       extends Serializable {
     
    -  def this(parent: Params, name: String, doc: String) =
    +  def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
    +    this(parent.uid, name, doc, isValid)
    +
    +  def this(parent: String, name: String, doc: String) =
         this(parent, name, doc, ParamValidators.alwaysTrue[T])
     
    +  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
    +
       /**
        * Assert that the given value is valid for this parameter.
        *
    @@ -60,41 +65,34 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
        */
       private[param] def validate(value: T): Unit = {
         if (!isValid(value)) {
    -      throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." +
    -        s" Parameter description: $toString")
    +      throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.")
         }
       }
     
    -  /**
    -   * Creates a param pair with the given value (for Java).
    -   */
    +  /** Creates a param pair with the given value (for Java). */
       def w(value: T): ParamPair[T] = this -> value
     
    -  /**
    -   * Creates a param pair with the given value (for Scala).
    -   */
    +  /** Creates a param pair with the given value (for Scala). */
       def ->(value: T): ParamPair[T] = ParamPair(this, value)
     
    -  /**
    -   * Converts this param's name, doc, and optionally its default value and the user-supplied
    -   * value in its parent to string.
    -   */
    -  override def toString: String = {
    -    val valueStr = if (parent.isDefined(this)) {
    -      val defaultValueStr = parent.getDefault(this).map("default: " + _)
    -      val currentValueStr = parent.get(this).map("current: " + _)
    -      (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
    -    } else {
    -      "(undefined)"
    +  override final def toString: String = s"${parent}__$name"
    +
    +  override final def hashCode: Int = toString.##
    +
    +  override final def equals(obj: Any): Boolean = {
    +    obj match {
    +      case p: Param[_] => (p.parent == parent) && (p.name == name)
    +      case _ => false
         }
    -    s"$name: $doc $valueStr"
       }
     }
     
     /**
    + * :: DeveloperApi ::
      * Factory methods for common validation functions for [[Param.isValid]].
      * The numerical methods only support Int, Long, Float, and Double.
      */
    +@DeveloperApi
     object ParamValidators {
     
       /** (private[param]) Default validation always return true */
    @@ -172,69 +170,136 @@ object ParamValidators {
     
     // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
     
    -/** Specialized version of [[Param[Double]]] for Java. */
    -class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean)
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Double]]] for Java.
    + */
    +@DeveloperApi
    +class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean)
       extends Param[Double](parent, name, doc, isValid) {
     
    -  def this(parent: Params, name: String, doc: String) =
    +  def this(parent: String, name: String, doc: String) =
         this(parent, name, doc, ParamValidators.alwaysTrue)
     
    +  def this(parent: Identifiable, name: String, doc: String, isValid: Double => Boolean) =
    +    this(parent.uid, name, doc, isValid)
    +
    +  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
    +
    +  /** Creates a param pair with the given value (for Java). */
       override def w(value: Double): ParamPair[Double] = super.w(value)
     }
     
    -/** Specialized version of [[Param[Int]]] for Java. */
    -class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean)
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Int]]] for Java.
    + */
    +@DeveloperApi
    +class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean)
       extends Param[Int](parent, name, doc, isValid) {
     
    -  def this(parent: Params, name: String, doc: String) =
    +  def this(parent: String, name: String, doc: String) =
         this(parent, name, doc, ParamValidators.alwaysTrue)
     
    +  def this(parent: Identifiable, name: String, doc: String, isValid: Int => Boolean) =
    +    this(parent.uid, name, doc, isValid)
    +
    +  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
    +
    +  /** Creates a param pair with the given value (for Java). */
       override def w(value: Int): ParamPair[Int] = super.w(value)
     }
     
    -/** Specialized version of [[Param[Float]]] for Java. */
    -class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean)
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Float]]] for Java.
    + */
    +@DeveloperApi
    +class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean)
       extends Param[Float](parent, name, doc, isValid) {
     
    -  def this(parent: Params, name: String, doc: String) =
    +  def this(parent: String, name: String, doc: String) =
         this(parent, name, doc, ParamValidators.alwaysTrue)
     
    +  def this(parent: Identifiable, name: String, doc: String, isValid: Float => Boolean) =
    +    this(parent.uid, name, doc, isValid)
    +
    +  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
    +
    +  /** Creates a param pair with the given value (for Java). */
       override def w(value: Float): ParamPair[Float] = super.w(value)
     }
     
    -/** Specialized version of [[Param[Long]]] for Java. */
    -class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean)
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Long]]] for Java.
    + */
    +@DeveloperApi
    +class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean)
       extends Param[Long](parent, name, doc, isValid) {
     
    -  def this(parent: Params, name: String, doc: String) =
    +  def this(parent: String, name: String, doc: String) =
         this(parent, name, doc, ParamValidators.alwaysTrue)
     
    +  def this(parent: Identifiable, name: String, doc: String, isValid: Long => Boolean) =
    +    this(parent.uid, name, doc, isValid)
    +
    +  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
    +
    +  /** Creates a param pair with the given value (for Java). */
       override def w(value: Long): ParamPair[Long] = super.w(value)
     }
     
    -/** Specialized version of [[Param[Boolean]]] for Java. */
    -class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Boolean]]] for Java.
    + */
    +@DeveloperApi
    +class BooleanParam(parent: String, name: String, doc: String) // No need for isValid
       extends Param[Boolean](parent, name, doc) {
     
    +  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
    +
    +  /** Creates a param pair with the given value (for Java). */
       override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
     }
     
    -/** Specialized version of [[Param[Array[T]]]] for Java. */
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Array[String]]]] for Java.
    + */
    +@DeveloperApi
     class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
       extends Param[Array[String]](parent, name, doc, isValid) {
     
       def this(parent: Params, name: String, doc: String) =
         this(parent, name, doc, ParamValidators.alwaysTrue)
     
    -  override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
    -
       /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
       def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
     }
     
     /**
    - * A param amd its value.
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Array[Double]]]] for Java.
    + */
    +@DeveloperApi
    +class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean)
    +  extends Param[Array[Double]](parent, name, doc, isValid) {
    +
    +  def this(parent: Params, name: String, doc: String) =
    +    this(parent, name, doc, ParamValidators.alwaysTrue)
    +
    +  /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
    +  def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
    +    w(value.asScala.map(_.asInstanceOf[Double]).toArray)
    +}
    +
    +/**
    + * :: Experimental ::
    + * A param and its value.
      */
    +@Experimental
     case class ParamPair[T](param: Param[T], value: T) {
       // This is *the* place Param.validate is called.  Whenever a parameter is specified, we should
       // always construct a ParamPair so that validate is called.
    @@ -242,16 +307,19 @@ case class ParamPair[T](param: Param[T], value: T) {
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: DeveloperApi ::
      * Trait for components that take parameters. This also provides an internal param map to store
      * parameter values attached to the instance.
      */
    -@AlphaComponent
    +@DeveloperApi
     trait Params extends Identifiable with Serializable {
     
       /**
        * Returns all params sorted by their names. The default implementation uses Java reflection to
        * list all public methods that have no arguments and return [[Param]].
    +   *
     +   * Note: Developers should not use this method in constructors because we cannot guarantee
     +   * that this variable gets initialized before other params.
        */
       lazy val params: Array[Param[_]] = {
         val methods = this.getClass.getMethods
    @@ -264,37 +332,43 @@ trait Params extends Identifiable with Serializable {
       }
     
       /**
    -   * Validates parameter values stored internally plus the input parameter map.
    -   * Raises an exception if any parameter is invalid.
    +   * Validates parameter values stored internally.
    +   * Raise an exception if any parameter value is invalid.
        *
        * This only needs to check for interactions between parameters.
        * Parameter value checks which do not depend on other parameters are handled by
        * [[Param.validate()]].  This method does not handle input/output column parameters;
        * those are checked during schema validation.
        */
    -  def validateParams(paramMap: ParamMap): Unit = {
    -    copy(paramMap).validateParams()
    +  def validateParams(): Unit = {
    +    // Do nothing by default.  Override to handle Param interactions.
       }
     
       /**
    -   * Validates parameter values stored internally.
    -   * Raise an exception if any parameter value is invalid.
    -   *
    -   * This only needs to check for interactions between parameters.
    -   * Parameter value checks which do not depend on other parameters are handled by
    -   * [[Param.validate()]].  This method does not handle input/output column parameters;
    -   * those are checked during schema validation.
    +   * Explains a param.
    +   * @param param input param, must belong to this instance.
    +   * @return a string that contains the input param name, doc, and optionally its default value and
    +   *         the user-supplied value
        */
    -  def validateParams(): Unit = {
    -    params.filter(isDefined _).foreach { param =>
    -      param.asInstanceOf[Param[Any]].validate($(param))
    +  def explainParam(param: Param[_]): String = {
    +    shouldOwn(param)
    +    val valueStr = if (isDefined(param)) {
    +      val defaultValueStr = getDefault(param).map("default: " + _)
    +      val currentValueStr = get(param).map("current: " + _)
    +      (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
    +    } else {
    +      "(undefined)"
         }
    +    s"${param.name}: ${param.doc} $valueStr"
       }
     
       /**
    -   * Returns the documentation of all params.
    +   * Explains all params of this instance.
    +   * @see [[explainParam()]]
        */
    -  def explainParams(): String = params.mkString("\n")
    +  def explainParams(): String = {
    +    params.map(explainParam).mkString("\n")
    +  }
     
       /** Checks whether a param is explicitly set. */
       final def isSet(param: Param[_]): Boolean = {
    @@ -379,20 +453,18 @@ trait Params extends Identifiable with Serializable {
        * @param value  the default value
        */
       protected final def setDefault[T](param: Param[T], value: T): this.type = {
    -    shouldOwn(param)
    -    defaultParamMap.put(param, value)
    +    defaultParamMap.put(param -> value)
         this
       }
     
       /**
        * Sets default values for a list of params.
        *
    -   * Note: Java developers should use the single-parameter [[setDefault()]].
    -   *       Annotating this with varargs causes compilation failures. See SPARK-7498.
        * @param paramPairs  a list of param pairs that specify params and their default values to set
        *                    respectively. Make sure that the params are initialized before this method
        *                    gets called.
        */
    +  @varargs
       protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
         paramPairs.foreach { p =>
           setDefault(p.param.asInstanceOf[Param[Any]], p.value)
    @@ -417,24 +489,30 @@ trait Params extends Identifiable with Serializable {
       }
     
       /**
    -   * Creates a copy of this instance with a randomly generated uid and some extra params.
    -   * The default implementation calls the default constructor to create a new instance, then
    -   * copies the embedded and extra parameters over and returns the new instance.
    -   * Subclasses should override this method if the default approach is not sufficient.
    +   * Creates a copy of this instance with the same UID and some extra params.
    +   * Subclasses should implement this method and set the return type properly.
    +   *
    +   * @see [[defaultCopy()]]
    +   */
    +  def copy(extra: ParamMap): Params
    +
    +  /**
    +   * Default implementation of copy with extra params.
    +   * It tries to create a new instance with the same UID.
    +   * Then it copies the embedded and extra parameters over and returns the new instance.
        */
    -  def copy(extra: ParamMap): Params = {
    -    val that = this.getClass.newInstance()
    -    copyValues(that, extra)
    -    that
    +  protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
    +    val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
    +    copyValues(that, extra).asInstanceOf[T]
       }
     
       /**
        * Extracts the embedded default param values and user-supplied values, and then merges them with
        * extra values from input into a flat param map, where the latter value is used if there exist
    -   * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap.
    +   * conflicts, i.e., with ordering: default param values < user-supplied values < extra.
        */
    -  final def extractParamMap(extraParamMap: ParamMap): ParamMap = {
    -    defaultParamMap ++ paramMap ++ extraParamMap
    +  final def extractParamMap(extra: ParamMap): ParamMap = {
    +    defaultParamMap ++ paramMap ++ extra
       }
     
       /**
    @@ -452,7 +530,7 @@ trait Params extends Identifiable with Serializable {
     
       /** Validates that the input param belongs to this instance. */
       private def shouldOwn(param: Param[_]): Unit = {
    -    require(param.parent.eq(this), s"Param $param does not belong to $this.")
    +    require(param.parent == uid && hasParam(param.name), s"Param $param does not belong to $this.")
       }
     
       /**
    @@ -473,18 +551,20 @@ trait Params extends Identifiable with Serializable {
     }
     
     /**
    + * :: DeveloperApi ::
      * Java-friendly wrapper for [[Params]].
      * Java developers who need to extend [[Params]] should use this class instead.
      * If you need to extend an abstract class which already extends [[Params]], then that abstract
      * class should be Java-friendly as well.
      */
    +@DeveloperApi
     abstract class JavaParams extends Params
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * A param to value map.
      */
    -@AlphaComponent
    +@Experimental
     final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
       extends Serializable {
     
    @@ -502,7 +582,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
       /**
        * Puts a (param, value) pair (overwrites if the input param exists).
        */
    -  def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value))
    +  def put[T](param: Param[T], value: T): this.type = put(param -> value)
     
       /**
        * Puts a list of param pairs (overwrites if the input params exists).
    @@ -568,7 +648,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
     
       override def toString: String = {
         map.toSeq.sortBy(_._1.name).map { case (param, value) =>
    -      s"\t${param.parent.uid}-${param.name}: $value"
    +      s"\t${param.parent}-${param.name}: $value"
         }.mkString("{\n", ",\n", "\n}")
       }
     
    @@ -605,6 +685,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
       def size: Int = map.size
     }
     
    +@Experimental
     object ParamMap {
     
       /**
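
The reworked Param/Params API above keys params by the owner's uid string (Param.toString is now "<uid>__<name>", with equality based on the (parent, name) pair), moves the old descriptive toString into explainParam/explainParams, and introduces copy/defaultCopy, which preserve the uid. A minimal sketch, not part of the patch, that uses RegexTokenizer only as a convenient Params owner:

    import org.apache.spark.ml.feature.RegexTokenizer
    import org.apache.spark.ml.param.ParamMap

    object ParamsApiExample {
      def main(args: Array[String]): Unit = {
        val tok = new RegexTokenizer()   // uid is generated, e.g. "regexTok_..."
        println(tok.uid)

        // Param.toString is now "<parent uid>__<param name>".
        println(tok.minTokenLength)

        // Per-param and all-params documentation, including defaults and user-set values.
        tok.setMinTokenLength(2)
        println(tok.explainParam(tok.minTokenLength))
        println(tok.explainParams())

        // copy() keeps the same uid and layers extra values on top of the embedded ones.
        val copied = tok.copy(ParamMap(tok.gaps -> false))
        println(copied.uid == tok.uid)   // true
        println(copied.getGaps)          // false
      }
    }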
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    index 5085b798daa17..f7ae1de522e01 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    @@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen {
         val params = Seq(
           ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
             isValid = "ParamValidators.gtEq(0)"),
    -      ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
    +      ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)",
             isValid = "ParamValidators.gtEq(0)"),
           ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
           ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
    @@ -49,11 +49,13 @@ private[shared] object SharedParamsCodeGen {
             isValid = "ParamValidators.inRange(0, 1)"),
           ParamDesc[String]("inputCol", "input column name"),
           ParamDesc[Array[String]]("inputCols", "input column names"),
    -      ParamDesc[String]("outputCol", "output column name"),
    +      ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
           ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
             isValid = "ParamValidators.gtEq(1)"),
           ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
    -      ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
    +      ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
    +        " before fitting the model.", Some("true")),
    +      ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
           ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
             " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
             isValid = "ParamValidators.inRange(0, 1)"),
    @@ -132,7 +134,7 @@ private[shared] object SharedParamsCodeGen {
     
         s"""
           |/**
    -      | * (private[ml]) Trait for shared param $name$defaultValueDoc.
    +      | * Trait for shared param $name$defaultValueDoc.
           | */
           |private[ml] trait Has$Name extends Params {
           |
    @@ -171,7 +173,6 @@ private[shared] object SharedParamsCodeGen {
             |package org.apache.spark.ml.param.shared
             |
             |import org.apache.spark.ml.param._
    -        |import org.apache.spark.util.Utils
             |
             |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
             |
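
The codegen changes above alter two shared defaults: seed now defaults to a deterministic per-class value (this.getClass.getName.hashCode.toLong) instead of a random one, and outputCol defaults to uid + "__output"; both show up in the regenerated sharedParams.scala below, along with the new standardization param. A minimal sketch of the observable effect, not part of the patch; ALS and VectorAssembler are used here only because they mix in HasSeed and HasOutputCol respectively.

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.ml.recommendation.ALS

    object SharedParamDefaultsExample {
      def main(args: Array[String]): Unit = {
        // Two instances of the same class now share the same default seed ...
        val als1 = new ALS()
        val als2 = new ALS()
        println(als1.getSeed == als2.getSeed)                           // true
        println(als1.getSeed == classOf[ALS].getName.hashCode.toLong)   // true, by construction in this patch

        // ... and outputCol defaults to "<uid>__output", so it no longer collides across stages.
        val assembler = new VectorAssembler()
        println(assembler.getOutputCol == assembler.uid + "__output")   // true
      }
    }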
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    index 7525d37007377..65e48e4ee5083 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    @@ -18,14 +18,13 @@
     package org.apache.spark.ml.param.shared
     
     import org.apache.spark.ml.param._
    -import org.apache.spark.util.Utils
     
     // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
     
     // scalastyle:off
     
     /**
    - * (private[ml]) Trait for shared param regParam.
    + * Trait for shared param regParam.
      */
     private[ml] trait HasRegParam extends Params {
     
    @@ -40,22 +39,22 @@ private[ml] trait HasRegParam extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param maxIter.
    + * Trait for shared param maxIter.
      */
     private[ml] trait HasMaxIter extends Params {
     
       /**
    -   * Param for max number of iterations (>= 0).
    +   * Param for maximum number of iterations (>= 0).
        * @group param
        */
    -  final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
    +  final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0))
     
       /** @group getParam */
       final def getMaxIter: Int = $(maxIter)
     }
     
     /**
    - * (private[ml]) Trait for shared param featuresCol (default: "features").
    + * Trait for shared param featuresCol (default: "features").
      */
     private[ml] trait HasFeaturesCol extends Params {
     
    @@ -72,7 +71,7 @@ private[ml] trait HasFeaturesCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param labelCol (default: "label").
    + * Trait for shared param labelCol (default: "label").
      */
     private[ml] trait HasLabelCol extends Params {
     
    @@ -89,7 +88,7 @@ private[ml] trait HasLabelCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param predictionCol (default: "prediction").
    + * Trait for shared param predictionCol (default: "prediction").
      */
     private[ml] trait HasPredictionCol extends Params {
     
    @@ -106,7 +105,7 @@ private[ml] trait HasPredictionCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
    + * Trait for shared param rawPredictionCol (default: "rawPrediction").
      */
     private[ml] trait HasRawPredictionCol extends Params {
     
    @@ -123,7 +122,7 @@ private[ml] trait HasRawPredictionCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param probabilityCol (default: "probability").
    + * Trait for shared param probabilityCol (default: "probability").
      */
     private[ml] trait HasProbabilityCol extends Params {
     
    @@ -140,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param threshold.
    + * Trait for shared param threshold.
      */
     private[ml] trait HasThreshold extends Params {
     
    @@ -155,7 +154,7 @@ private[ml] trait HasThreshold extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param inputCol.
    + * Trait for shared param inputCol.
      */
     private[ml] trait HasInputCol extends Params {
     
    @@ -170,7 +169,7 @@ private[ml] trait HasInputCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param inputCols.
    + * Trait for shared param inputCols.
      */
     private[ml] trait HasInputCols extends Params {
     
    @@ -185,7 +184,7 @@ private[ml] trait HasInputCols extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param outputCol.
    + * Trait for shared param outputCol (default: uid + "__output").
      */
     private[ml] trait HasOutputCol extends Params {
     
    @@ -195,12 +194,14 @@ private[ml] trait HasOutputCol extends Params {
        */
       final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
     
    +  setDefault(outputCol, uid + "__output")
    +
       /** @group getParam */
       final def getOutputCol: String = $(outputCol)
     }
     
     /**
    - * (private[ml]) Trait for shared param checkpointInterval.
    + * Trait for shared param checkpointInterval.
      */
     private[ml] trait HasCheckpointInterval extends Params {
     
    @@ -215,7 +216,7 @@ private[ml] trait HasCheckpointInterval extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param fitIntercept (default: true).
    + * Trait for shared param fitIntercept (default: true).
      */
     private[ml] trait HasFitIntercept extends Params {
     
    @@ -232,7 +233,24 @@ private[ml] trait HasFitIntercept extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()).
    + * Trait for shared param standardization (default: true).
    + */
    +private[ml] trait HasStandardization extends Params {
    +
    +  /**
     +   * Param for whether to standardize the training features before fitting the model.
    +   * @group param
    +   */
    +  final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.")
    +
    +  setDefault(standardization, true)
    +
    +  /** @group getParam */
    +  final def getStandardization: Boolean = $(standardization)
    +}
    +
    +/**
    + * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
      */
     private[ml] trait HasSeed extends Params {
     
    @@ -242,14 +260,14 @@ private[ml] trait HasSeed extends Params {
        */
       final val seed: LongParam = new LongParam(this, "seed", "random seed")
     
    -  setDefault(seed, Utils.random.nextLong())
    +  setDefault(seed, this.getClass.getName.hashCode.toLong)
     
       /** @group getParam */
       final def getSeed: Long = $(seed)
     }
     
     /**
    - * (private[ml]) Trait for shared param elasticNetParam.
    + * Trait for shared param elasticNetParam.
      */
     private[ml] trait HasElasticNetParam extends Params {
     
    @@ -264,7 +282,7 @@ private[ml] trait HasElasticNetParam extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param tol.
    + * Trait for shared param tol.
      */
     private[ml] trait HasTol extends Params {
     
    @@ -279,7 +297,7 @@ private[ml] trait HasTol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param stepSize.
    + * Trait for shared param stepSize.
      */
     private[ml] trait HasStepSize extends Params {
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
    index d7cbffc3be26f..2e44cd4cc6a22 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
    @@ -31,24 +31,50 @@ import org.apache.hadoop.fs.{FileSystem, Path}
     import org.netlib.util.intW
     
     import org.apache.spark.{Logging, Partitioner}
    -import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.annotation.{DeveloperApi, Experimental}
     import org.apache.spark.ml.{Estimator, Model}
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
     import org.apache.spark.mllib.optimization.NNLS
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     import org.apache.spark.sql.functions._
    -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
    +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
     import org.apache.spark.storage.StorageLevel
     import org.apache.spark.util.Utils
     import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
     import org.apache.spark.util.random.XORShiftRandom
     
    +/**
    + * Common params for ALS and ALSModel.
    + */
    +private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
    +  /**
    +   * Param for the column name for user ids.
    +   * Default: "user"
    +   * @group param
    +   */
    +  val userCol = new Param[String](this, "userCol", "column name for user ids")
    +
    +  /** @group getParam */
    +  def getUserCol: String = $(userCol)
    +
    +  /**
    +   * Param for the column name for item ids.
    +   * Default: "item"
    +   * @group param
    +   */
    +  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
    +
    +  /** @group getParam */
    +  def getItemCol: String = $(itemCol)
    +}
    +
     /**
      * Common params for ALS.
      */
    -private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
    +private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam
       with HasPredictionCol with HasCheckpointInterval with HasSeed {
     
       /**
    @@ -104,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
       /** @group getParam */
       def getAlpha: Double = $(alpha)
     
    -  /**
    -   * Param for the column name for user ids.
    -   * Default: "user"
    -   * @group param
    -   */
    -  val userCol = new Param[String](this, "userCol", "column name for user ids")
    -
    -  /** @group getParam */
    -  def getUserCol: String = $(userCol)
    -
    -  /**
    -   * Param for the column name for item ids.
    -   * Default: "item"
    -   * @group param
    -   */
    -  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
    -
    -  /** @group getParam */
    -  def getItemCol: String = $(itemCol)
    -
       /**
        * Param for the column name for ratings.
        * Default: "rating"
    @@ -147,7 +153,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
     
       setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
         implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
    -    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L)
    +    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
     
       /**
        * Validates and transforms the input schema.
    @@ -155,58 +161,71 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
        * @return output schema
        */
       protected def validateAndTransformSchema(schema: StructType): StructType = {
    -    require(schema($(userCol)).dataType == IntegerType)
    -    require(schema($(itemCol)).dataType== IntegerType)
    +    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
    +    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
         val ratingType = schema($(ratingCol)).dataType
         require(ratingType == FloatType || ratingType == DoubleType)
    -    val predictionColName = $(predictionCol)
    -    require(!schema.fieldNames.contains(predictionColName),
    -      s"Prediction column $predictionColName already exists.")
    -    val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false)
    -    StructType(newFields)
    +    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
       }
     }
     
     /**
    + * :: Experimental ::
      * Model fitted by ALS.
    + *
    + * @param rank rank of the matrix factorization model
    + * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
    + * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
      */
    +@Experimental
     class ALSModel private[ml] (
    -    override val parent: ALS,
    -    k: Int,
    -    userFactors: RDD[(Int, Array[Float])],
    -    itemFactors: RDD[(Int, Array[Float])])
    -  extends Model[ALSModel] with ALSParams {
    +    override val uid: String,
    +    val rank: Int,
    +    @transient val userFactors: DataFrame,
    +    @transient val itemFactors: DataFrame)
    +  extends Model[ALSModel] with ALSModelParams {
    +
    +  /** @group setParam */
    +  def setUserCol(value: String): this.type = set(userCol, value)
    +
    +  /** @group setParam */
    +  def setItemCol(value: String): this.type = set(itemCol, value)
     
       /** @group setParam */
       def setPredictionCol(value: String): this.type = set(predictionCol, value)
     
       override def transform(dataset: DataFrame): DataFrame = {
    -    import dataset.sqlContext.implicits._
    -    val users = userFactors.toDF("id", "features")
    -    val items = itemFactors.toDF("id", "features")
    -
         // Register a UDF for DataFrame, and then
         // create a new column named map(predictionCol) by running the predict UDF.
         val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
           if (userFeatures != null && itemFeatures != null) {
    -        blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
    +        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
           } else {
             Float.NaN
           }
         }
         dataset
    -      .join(users, dataset($(userCol)) === users("id"), "left")
    -      .join(items, dataset($(itemCol)) === items("id"), "left")
    -      .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol)))
    +      .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
    +      .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
    +      .select(dataset("*"),
    +        predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
       }
     
       override def transformSchema(schema: StructType): StructType = {
    -    validateAndTransformSchema(schema)
    +    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
    +    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
    +    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
    +  }
    +
    +  override def copy(extra: ParamMap): ALSModel = {
    +    val copied = new ALSModel(uid, rank, userFactors, itemFactors)
    +    copyValues(copied, extra)
       }
     }
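    +
    +// A minimal sketch of what ALSModel.transform produces (illustrative; `df` and `model` are
    +// assumed to exist, with the default "user"/"item"/"prediction" column names left unchanged):
    +//
    +//   val scored = model.transform(df)
    +//   // `scored` gains a float "prediction" column holding the dot product of the joined user
    +//   // and item factor vectors; ids with no factors (left-join misses) yield Float.NaN.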
     
     
     /**
    + * :: Experimental ::
      * Alternating Least Squares (ALS) matrix factorization.
      *
      * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
    @@ -235,10 +254,13 @@ class ALSModel private[ml] (
      * indicated user
      * preferences rather than explicit ratings given to items.
      */
    -class ALS extends Estimator[ALSModel] with ALSParams {
    +@Experimental
    +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
     
       import org.apache.spark.ml.recommendation.ALS.Rating
     
    +  def this() = this(Identifiable.randomUID("als"))
    +
       /** @group setParam */
       def setRank(value: Int): this.type = set(rank, value)
     
    @@ -292,6 +314,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
       }
     
       override def fit(dataset: DataFrame): ALSModel = {
    +    import dataset.sqlContext.implicits._
         val ratings = dataset
           .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
             col($(ratingCol)).cast(FloatType))
    @@ -303,12 +326,17 @@ class ALS extends Estimator[ALSModel] with ALSParams {
           maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
           alpha = $(alpha), nonnegative = $(nonnegative),
           checkpointInterval = $(checkpointInterval), seed = $(seed))
    -    copyValues(new ALSModel(this, $(rank), userFactors, itemFactors))
    +    val userDF = userFactors.toDF("id", "features")
    +    val itemDF = itemFactors.toDF("id", "features")
    +    val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
    +    copyValues(model)
       }
     
       override def transformSchema(schema: StructType): StructType = {
         validateAndTransformSchema(schema)
       }
    +
    +  override def copy(extra: ParamMap): ALS = defaultCopy(extra)
     }
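    +
    +// A minimal usage sketch (illustrative; `ratings` is an assumed DataFrame with integer
    +// "user"/"item" columns and a float or double "rating" column, matching the defaults above):
    +//
    +//   val als = new ALS()
    +//     .setRank(10)
    +//     .setMaxIter(10)
    +//     .setRegParam(0.1)
    +//   val model = als.fit(ratings)
    +//   val predictions = model.transform(ratings)   // appends the "prediction" column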
     
     /**
    @@ -322,7 +350,11 @@ class ALS extends Estimator[ALSModel] with ALSParams {
     @DeveloperApi
     object ALS extends Logging {
     
    -  /** Rating class for better code readability. */
    +  /**
    +   * :: DeveloperApi ::
    +   * Rating class for better code readability.
    +   */
    +  @DeveloperApi
       case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
     
       /** Trait for least squares solvers applied to the normal equation. */
    @@ -483,8 +515,10 @@ object ALS extends Logging {
       }
     
       /**
    +   * :: DeveloperApi ::
        * Implementation of the ALS algorithm.
        */
    +  @DeveloperApi
       def train[ID: ClassTag]( // scalastyle:ignore
           ratings: RDD[Rating[ID]],
           rank: Int = 10,
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    index f8f0b161a4812..be1f8063d41d8 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    @@ -17,11 +17,11 @@
     
     package org.apache.spark.ml.regression
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
    -import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node}
    -import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
    +import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
    @@ -31,17 +31,18 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
      * for regression.
      * It supports both continuous and categorical features.
      */
    -@AlphaComponent
    -final class DecisionTreeRegressor
    +@Experimental
    +final class DecisionTreeRegressor(override val uid: String)
       extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
       with DecisionTreeParams with TreeRegressorParams {
     
    +  def this() = this(Identifiable.randomUID("dtr"))
    +
       // Override parameter setters from parent trait for Java API compatibility.
     
       override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
    @@ -75,23 +76,25 @@ final class DecisionTreeRegressor
         super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
           subsamplingRate = 1.0)
       }
    +
    +  override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra)
     }
     
    +@Experimental
     object DecisionTreeRegressor {
       /** Accessor for supported impurities: variance */
       final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
      * It supports both continuous and categorical features.
      * @param rootNode  Root of the decision tree
      */
    -@AlphaComponent
    +@Experimental
     final class DecisionTreeRegressionModel private[ml] (
    -    override val parent: DecisionTreeRegressor,
    +    override val uid: String,
         override val rootNode: Node)
       extends PredictionModel[Vector, DecisionTreeRegressionModel]
       with DecisionTreeModel with Serializable {
    @@ -104,7 +107,7 @@ final class DecisionTreeRegressionModel private[ml] (
       }
     
       override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
    -    copyValues(new DecisionTreeRegressionModel(parent, rootNode), extra)
    +    copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra)
       }
     
       override def toString: String = {
    @@ -128,6 +131,7 @@ private[ml] object DecisionTreeRegressionModel {
           s"Cannot convert non-regression DecisionTreeModel (old API) to" +
             s" DecisionTreeRegressionModel (new API).  Algo is: ${oldModel.algo}")
         val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
    -    new DecisionTreeRegressionModel(parent, rootNode)
    +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
    +    new DecisionTreeRegressionModel(uid, rootNode)
       }
     }
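    +
    +// A minimal usage sketch (illustrative; `training` is an assumed DataFrame with "label" and
    +// "features" columns):
    +//
    +//   val dtr = new DecisionTreeRegressor()
    +//     .setMaxDepth(5)
    +//     .setImpurity("variance")
    +//   val model = dtr.fit(training)
    +//   println(model.toString)   // short summary of the fitted tree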
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    index 461905c12701a..47c110d027d67 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    @@ -20,11 +20,11 @@ package org.apache.spark.ml.regression
     import com.github.fommil.netlib.BLAS.{getInstance => blas}
     
     import org.apache.spark.Logging
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.{Param, ParamMap}
    -import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
    -import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams}
    +import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
    @@ -35,17 +35,18 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
      * learning algorithm for regression.
      * It supports both continuous and categorical features.
      */
    -@AlphaComponent
    -final class GBTRegressor
    +@Experimental
    +final class GBTRegressor(override val uid: String)
       extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
       with GBTParams with TreeRegressorParams with Logging {
     
    +  def this() = this(Identifiable.randomUID("gbtr"))
    +
       // Override parameter setters from parent trait for Java API compatibility.
     
       // Parameters from TreeRegressorParams:
    @@ -130,8 +131,11 @@ final class GBTRegressor
         val oldModel = oldGBT.run(oldDataset)
         GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
       }
    +
    +  override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
     }
     
    +@Experimental
     object GBTRegressor {
       // The losses below should be lowercase.
       /** Accessor for supported loss settings: squared (L2), absolute (L1) */
    @@ -139,7 +143,7 @@ object GBTRegressor {
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      *
      * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
      * model for regression.
    @@ -147,9 +151,9 @@ object GBTRegressor {
      * @param _trees  Decision trees in the ensemble.
      * @param _treeWeights  Weights for the decision trees in the ensemble.
      */
    -@AlphaComponent
    +@Experimental
     final class GBTRegressionModel(
    -    override val parent: GBTRegressor,
    +    override val uid: String,
         private val _trees: Array[DecisionTreeRegressionModel],
         private val _treeWeights: Array[Double])
       extends PredictionModel[Vector, GBTRegressionModel]
    @@ -168,12 +172,11 @@ final class GBTRegressionModel(
         // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
    -    // Classifies by thresholding sum of weighted tree predictions
    +    // Predicts by summing the weighted predictions of the individual trees
         val treePredictions = _trees.map(_.rootNode.predict(features))
    -    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
    -    if (prediction > 0.0) 1.0 else 0.0
    +    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
       }
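    +
    +  // Worked example (illustrative numbers): for treePredictions = Array(1.0, 2.0) and
    +  // _treeWeights = Array(1.0, 0.5), blas.ddot returns 1.0 * 1.0 + 2.0 * 0.5 = 2.0,
    +  // i.e. the weighted sum of the individual tree predictions.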
     
       override def copy(extra: ParamMap): GBTRegressionModel = {
    -    copyValues(new GBTRegressionModel(parent, _trees, _treeWeights), extra)
    +    copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra)
       }
     
       override def toString: String = {
    @@ -196,9 +199,10 @@ private[ml] object GBTRegressionModel {
         require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
           s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
         val newTrees = oldModel.trees.map { tree =>
    -      // parent, fittingParamMap for each tree is null since there are no good ways to set these.
    +      // parent for each tree is null since there is no good way to set this.
           DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
         }
    -    new GBTRegressionModel(parent, newTrees, oldModel.treeWeights)
    +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
    +    new GBTRegressionModel(uid, newTrees, oldModel.treeWeights)
       }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    index 6377923afc0c4..8fc986056657d 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    @@ -20,20 +20,22 @@ package org.apache.spark.ml.regression
     import scala.collection.mutable
     
     import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
    -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
    -  OWLQN => BreezeOWLQN}
    +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
     
    -import org.apache.spark.Logging
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.{Logging, SparkException}
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.PredictorParams
     import org.apache.spark.ml.param.ParamMap
    -import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
    +import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.evaluation.RegressionMetrics
     import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.linalg.BLAS._
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.{DataFrame, Row}
    +import org.apache.spark.sql.functions.{col, udf}
     import org.apache.spark.storage.StorageLevel
     import org.apache.spark.util.StatCounter
     
    @@ -41,11 +43,11 @@ import org.apache.spark.util.StatCounter
      * Params for linear regression.
      */
     private[regression] trait LinearRegressionParams extends PredictorParams
    -  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
    +    with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
    +    with HasFitIntercept
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Linear regression.
      *
      * The learning objective is to minimize the squared error, with regularization.
    @@ -58,10 +60,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams
      *  - L1 (Lasso)
      *  - L2 + L1 (elastic net)
      */
    -@AlphaComponent
    -class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
    +@Experimental
    +class LinearRegression(override val uid: String)
    +  extends Regressor[Vector, LinearRegression, LinearRegressionModel]
       with LinearRegressionParams with Logging {
     
    +  def this() = this(Identifiable.randomUID("linReg"))
    +
       /**
        * Set the regularization parameter.
        * Default is 0.0.
    @@ -70,6 +75,14 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
       def setRegParam(value: Double): this.type = set(regParam, value)
       setDefault(regParam -> 0.0)
     
    +  /**
    +   * Set whether to fit the intercept.
    +   * Default is true.
    +   * @group setParam
    +   */
    +  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
    +  setDefault(fitIntercept -> true)
    +
       /**
        * Set the ElasticNet mixing parameter.
        * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    @@ -81,7 +94,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
       setDefault(elasticNetParam -> 0.0)
     
       /**
    -   * Set the maximal number of iterations.
    +   * Set the maximum number of iterations.
        * Default is 100.
        * @group setParam
        */
    @@ -128,7 +141,16 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
           logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
             s"and the intercept will be the mean of the label; as a result, training is not needed.")
           if (handlePersistence) instances.unpersist()
    -      return new LinearRegressionModel(this, Vectors.sparse(numFeatures, Seq()), yMean)
    +      val weights = Vectors.sparse(numFeatures, Seq())
    +      val intercept = yMean
    +
    +      val model = new LinearRegressionModel(uid, weights, intercept)
    +      val trainingSummary = new LinearRegressionTrainingSummary(
    +        model.transform(dataset).select($(predictionCol), $(labelCol)),
    +        $(predictionCol),
    +        $(labelCol),
    +        Array(0D))
    +      return copyValues(model.setSummary(trainingSummary))
         }
     
         val featuresMean = summarizer.mean.toArray
    @@ -140,7 +162,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
         val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
         val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
     
    -    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
    +    val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
           featuresStd, featuresMean, effectiveL2RegParam)
     
         val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
    @@ -150,63 +172,197 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
         }
     
         val initialWeights = Vectors.zeros(numFeatures)
    -    val states =
    -      optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
    -
    -    var state = states.next()
    -    val lossHistory = mutable.ArrayBuilder.make[Double]
    -
    -    while (states.hasNext) {
    -      lossHistory += state.value
    -      state = states.next()
    -    }
    -    lossHistory += state.value
    +    val states = optimizer.iterations(new CachedDiffFunction(costFun),
    +      initialWeights.toBreeze.toDenseVector)
    +
    +    val (weights, objectiveHistory) = {
    +      /*
    +         Note that in Linear Regression, the objective history (loss + regularization) returned
    +         from the optimizer is computed in the scaled space given by the following formula.
    +         {{{
    +         L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + regTerms
    +         }}}
    +       */
    +      val arrayBuilder = mutable.ArrayBuilder.make[Double]
    +      var state: optimizer.State = null
    +      while (states.hasNext) {
    +        state = states.next()
    +        arrayBuilder += state.adjustedValue
    +      }
    +      if (state == null) {
    +        val msg = s"${optimizer.getClass.getName} failed."
    +        logError(msg)
    +        throw new SparkException(msg)
    +      }
     
    -    // The weights are trained in the scaled space; we're converting them back to
    -    // the original space.
    -    val weights = {
    +      /*
    +         The weights are trained in the scaled space; we're converting them back to
    +         the original space.
    +       */
           val rawWeights = state.x.toArray.clone()
           var i = 0
    -      while (i < rawWeights.length) {
    +      val len = rawWeights.length
    +      while (i < len) {
             rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
             i += 1
           }
    -      Vectors.dense(rawWeights)
    +
    +      (Vectors.dense(rawWeights).compressed, arrayBuilder.result())
         }
     
    -    // The intercept in R's GLMNET is computed using closed form after the coefficients are
    -    // converged. See the following discussion for detail.
    -    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
    -    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
    +    /*
    +       The intercept in R's GLMNET is computed using closed form after the coefficients are
    +       converged. See the following discussion for detail.
    +       http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
    +     */
    +    val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
    +
         if (handlePersistence) instances.unpersist()
     
    -    // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
    -    new LinearRegressionModel(this, weights.compressed, intercept)
    +    val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
    +    val trainingSummary = new LinearRegressionTrainingSummary(
    +      model.transform(dataset).select($(predictionCol), $(labelCol)),
    +      $(predictionCol),
    +      $(labelCol),
    +      objectiveHistory)
    +    model.setSummary(trainingSummary)
       }
    +
    +  override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
     }
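    +
    +// A minimal usage sketch (illustrative; `training` is an assumed DataFrame with "label" and
    +// "features" columns):
    +//
    +//   val lr = new LinearRegression()
    +//     .setRegParam(0.3)
    +//     .setElasticNetParam(1.0)   // pure L1 (Lasso)
    +//     .setFitIntercept(true)
    +//     .setMaxIter(100)
    +//   val model = lr.fit(training)
    +//   println(s"weights: ${model.weights}, intercept: ${model.intercept}")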
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * Model produced by [[LinearRegression]].
      */
    -@AlphaComponent
    +@Experimental
     class LinearRegressionModel private[ml] (
    -    override val parent: LinearRegression,
    +    override val uid: String,
         val weights: Vector,
         val intercept: Double)
       extends RegressionModel[Vector, LinearRegressionModel]
       with LinearRegressionParams {
     
    +  private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
    +
    +  /**
    +   * Gets the summary (e.g. residuals, MSE, R-squared) of the model on the training set. An
    +   * exception is thrown if `trainingSummary == None`.
    +   */
    +  def summary: LinearRegressionTrainingSummary = trainingSummary match {
    +    case Some(summ) => summ
    +    case None =>
    +      throw new SparkException(
    +        "No training summary available for this LinearRegressionModel",
    +        new NullPointerException())
    +  }
    +
    +  private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
    +    this.trainingSummary = Some(summary)
    +    this
    +  }
    +
    +  /** Indicates whether a training summary exists for this model instance. */
    +  def hasSummary: Boolean = trainingSummary.isDefined
    +
    +  /**
    +   * Evaluates the model on a test set.
    +   * @param dataset Test dataset to evaluate the model on.
    +   */
    +  // TODO: decide on a good name before exposing to public API
    +  private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
    +    val t = udf { features: Vector => predict(features) }
    +    val predictionAndObservations = dataset
    +      .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol)))
    +
    +    new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol))
    +  }
    +
       override protected def predict(features: Vector): Double = {
         dot(features, weights) + intercept
       }
     
       override def copy(extra: ParamMap): LinearRegressionModel = {
    -    copyValues(new LinearRegressionModel(parent, weights, intercept), extra)
    +    val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
    +    if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
    +    newModel
       }
     }
     
    +/**
    + * :: Experimental ::
    + * Linear regression training results.
    + * @param predictions predictions outputted by the model's `transform` method.
    + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
    + */
    +@Experimental
    +class LinearRegressionTrainingSummary private[regression] (
    +    predictions: DataFrame,
    +    predictionCol: String,
    +    labelCol: String,
    +    val objectiveHistory: Array[Double])
    +  extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
    +
    +  /** Number of training iterations until termination */
    +  val totalIterations = objectiveHistory.length
    +
    +}
    +
    +/**
    + * :: Experimental ::
    + * Linear regression results evaluated on a dataset.
    + * @param predictions predictions outputted by the model's `transform` method.
    + */
    +@Experimental
    +class LinearRegressionSummary private[regression] (
    +    @transient val predictions: DataFrame,
    +    val predictionCol: String,
    +    val labelCol: String) extends Serializable {
    +
    +  @transient private val metrics = new RegressionMetrics(
    +    predictions
    +      .select(predictionCol, labelCol)
    +      .map { case Row(pred: Double, label: Double) => (pred, label) })
    +
    +  /**
    +   * Returns the explained variance regression score.
    +   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
    +   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
    +   */
    +  val explainedVariance: Double = metrics.explainedVariance
    +
    +  /**
    +   * Returns the mean absolute error, which is a risk function corresponding to the
    +   * expected value of the absolute error loss or l1-norm loss.
    +   */
    +  val meanAbsoluteError: Double = metrics.meanAbsoluteError
    +
    +  /**
    +   * Returns the mean squared error, which is a risk function corresponding to the
    +   * expected value of the squared error loss or quadratic loss.
    +   */
    +  val meanSquaredError: Double = metrics.meanSquaredError
    +
    +  /**
    +   * Returns the root mean squared error, which is defined as the square root of
    +   * the mean squared error.
    +   */
    +  val rootMeanSquaredError: Double = metrics.rootMeanSquaredError
    +
    +  /**
    +   * Returns R^2^, the coefficient of determination.
    +   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
    +   */
    +  val r2: Double = metrics.r2
    +
    +  /** Residuals (predicted value - label value) */
    +  @transient lazy val residuals: DataFrame = {
    +    val t = udf { (pred: Double, label: Double) => pred - label }
    +    predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
    +  }
    +
    +}
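    +
    +// A sketch of inspecting the training summary (assumes `model` was produced by
    +// LinearRegression.fit, so a summary is attached):
    +//
    +//   if (model.hasSummary) {
    +//     val summary = model.summary
    +//     println(summary.rootMeanSquaredError)
    +//     println(summary.r2)
    +//     println(summary.objectiveHistory.mkString(", "))   // one value per iteration
    +//     summary.residuals.show()
    +//   }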
    +
     /**
      * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function,
      * as used in linear regression for samples in sparse or dense vector in a online fashion.
    @@ -230,6 +386,7 @@ class LinearRegressionModel private[ml] (
      * See this discussion for detail.
      * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
      *
    + * When training with the intercept enabled,
      * The objective function in the scaled space is given by
      * {{{
      * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
    @@ -237,6 +394,10 @@ class LinearRegressionModel private[ml] (
      * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
      * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
      *
    + * If fitting the intercept is disabled (that is, forced through 0.0),
    + * we can use the same equation, except that \bar{y} and \bar{x_i} are set to 0 instead
    + * of the respective means.
    + *
      * This can be rewritten as
      * {{{
      * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
    @@ -251,6 +412,7 @@ class LinearRegressionModel private[ml] (
      * \sum_i w_i^\prime x_i - y / \hat{y} + offset
      * }}}
      *
      * Note that the effective weights and offset don't depend on training dataset,
      * so they can be precomputed.
      *
    @@ -297,6 +459,7 @@ private class LeastSquaresAggregator(
         weights: Vector,
         labelStd: Double,
         labelMean: Double,
    +    fitIntercept: Boolean,
         featuresStd: Array[Double],
         featuresMean: Array[Double]) extends Serializable {
     
    @@ -307,7 +470,8 @@ private class LeastSquaresAggregator(
         val weightsArray = weights.toArray.clone()
         var sum = 0.0
         var i = 0
    -    while (i < weightsArray.length) {
    +    val len = weightsArray.length
    +    while (i < len) {
           if (featuresStd(i) != 0.0) {
             weightsArray(i) /=  featuresStd(i)
             sum += weightsArray(i) * featuresMean(i)
    @@ -316,9 +480,9 @@ private class LeastSquaresAggregator(
           }
           i += 1
         }
    -    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
    +    (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
       }
    -  
    +
       private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
     
       private val gradientSumArray = Array.ofDim[Double](dim)
    @@ -399,6 +563,7 @@ private class LeastSquaresCostFun(
         data: RDD[(Double, Vector)],
         labelStd: Double,
         labelMean: Double,
    +    fitIntercept: Boolean,
         featuresStd: Array[Double],
         featuresMean: Array[Double],
         effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
    @@ -407,7 +572,7 @@ private class LeastSquaresCostFun(
         val w = Vectors.fromBreeze(weights)
     
         val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
    -      labelMean, featuresStd, featuresMean))(
    +      labelMean, fitIntercept, featuresStd, featuresMean))(
             seqOp = (c, v) => (c, v) match {
               case (aggregator, (label, features)) => aggregator.add(label, features)
             },
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    index dbc628927433d..21c59061a02fa 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    @@ -17,11 +17,11 @@
     
     package org.apache.spark.ml.regression
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
    -import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
    -import org.apache.spark.ml.util.MetadataUtils
    +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
    +import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
    @@ -31,16 +31,17 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] learning algorithm for regression.
      * It supports both continuous and categorical features.
      */
    -@AlphaComponent
    -final class RandomForestRegressor
    +@Experimental
    +final class RandomForestRegressor(override val uid: String)
       extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
       with RandomForestParams with TreeRegressorParams {
     
    +  def this() = this(Identifiable.randomUID("rfr"))
    +
       // Override parameter setters from parent trait for Java API compatibility.
     
       // Parameters from TreeRegressorParams:
    @@ -85,8 +86,11 @@ final class RandomForestRegressor
           oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
         RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
       }
    +
    +  override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
     }
     
    +@Experimental
     object RandomForestRegressor {
       /** Accessor for supported impurity settings: variance */
       final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
    @@ -97,15 +101,14 @@ object RandomForestRegressor {
     }
     
     /**
    - * :: AlphaComponent ::
    - *
    + * :: Experimental ::
      * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] model for regression.
      * It supports both continuous and categorical features.
      * @param _trees  Decision trees in the ensemble.
      */
    -@AlphaComponent
    +@Experimental
     final class RandomForestRegressionModel private[ml] (
    -    override val parent: RandomForestRegressor,
    +    override val uid: String,
         private val _trees: Array[DecisionTreeRegressionModel])
       extends PredictionModel[Vector, RandomForestRegressionModel]
       with TreeEnsembleModel with Serializable {
    @@ -128,7 +131,7 @@ final class RandomForestRegressionModel private[ml] (
       }
     
       override def copy(extra: ParamMap): RandomForestRegressionModel = {
    -    copyValues(new RandomForestRegressionModel(parent, _trees), extra)
    +    copyValues(new RandomForestRegressionModel(uid, _trees), extra)
       }
     
       override def toString: String = {
    @@ -151,9 +154,9 @@ private[ml] object RandomForestRegressionModel {
         require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
           s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
         val newTrees = oldModel.trees.map { tree =>
    -      // parent, fittingParamMap for each tree is null since there are no good ways to set these.
    +      // parent for each tree is null since there is no good way to set this.
           DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
         }
    -    new RandomForestRegressionModel(parent, newTrees)
    +    new RandomForestRegressionModel(parent.uid, newTrees)
       }
     }
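    +
    +// A minimal usage sketch (illustrative; `training` is an assumed DataFrame with "label" and
    +// "features" columns):
    +//
    +//   val rf = new RandomForestRegressor()
    +//     .setNumTrees(20)
    +//     .setFeatureSubsetStrategy("auto")
    +//   val model = rf.fit(training)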
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    index d2dec0c76cb12..4242154be14ce 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    @@ -17,14 +17,16 @@
     
     package org.apache.spark.ml.tree
     
    +import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
       Node => OldNode, Predict => OldPredict}
     
    -
     /**
    + * :: DeveloperApi ::
      * Decision tree node interface.
      */
    +@DeveloperApi
     sealed abstract class Node extends Serializable {
     
       // TODO: Add aggregate stats (once available).  This will happen after we move the DecisionTree
    @@ -89,10 +91,12 @@ private[ml] object Node {
     }
     
     /**
    + * :: DeveloperApi ::
      * Decision tree leaf node.
      * @param prediction  Prediction this node makes
      * @param impurity  Impurity measure at this node (for training data)
      */
    +@DeveloperApi
     final class LeafNode private[ml] (
         override val prediction: Double,
         override val impurity: Double) extends Node {
    @@ -118,6 +122,7 @@ final class LeafNode private[ml] (
     }
     
     /**
    + * :: DeveloperApi ::
      * Internal Decision Tree node.
      * @param prediction  Prediction this node would make if it were a leaf node
      * @param impurity  Impurity measure at this node (for training data)
    @@ -127,6 +132,7 @@ final class LeafNode private[ml] (
      * @param rightChild  Right-hand child node
      * @param split  Information about the test used to split to the left or right child.
      */
    +@DeveloperApi
     final class InternalNode private[ml] (
         override val prediction: Double,
         override val impurity: Double,
    @@ -153,9 +159,9 @@ final class InternalNode private[ml] (
     
       override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
         val prefix: String = " " * indentFactor
    -    prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" +
    +    prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" +
           leftChild.subtreeToString(indentFactor + 1) +
    -      prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" +
    +      prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" +
           rightChild.subtreeToString(indentFactor + 1)
       }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    index 90f1d052764d3..7acdeeee72d23 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    @@ -17,15 +17,18 @@
     
     package org.apache.spark.ml.tree
     
    +import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
     import org.apache.spark.mllib.tree.model.{Split => OldSplit}
     
     
     /**
    + * :: DeveloperApi ::
      * Interface for a "Split," which specifies a test made at a decision tree node
      * to choose the left or right path.
      */
    +@DeveloperApi
     sealed trait Split extends Serializable {
     
       /** Index of feature which this split tests */
    @@ -52,12 +55,14 @@ private[tree] object Split {
     }
     
     /**
    + * :: DeveloperApi ::
      * Split which tests a categorical feature.
      * @param featureIndex  Index of the feature to test
      * @param _leftCategories  If the feature value is in this set of categories, then the split goes
      *                         left. Otherwise, it goes right.
      * @param numCategories  Number of categories for this feature.
      */
    +@DeveloperApi
     final class CategoricalSplit private[ml] (
         override val featureIndex: Int,
         _leftCategories: Array[Double],
    @@ -125,11 +130,13 @@ final class CategoricalSplit private[ml] (
     }
     
     /**
    + * :: DeveloperApi ::
      * Split which tests a continuous feature.
      * @param featureIndex  Index of the feature to test
      * @param threshold  If the feature value is <= this threshold, then the split goes left.
      *                    Otherwise, it goes right.
      */
    +@DeveloperApi
     final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
       extends Split {
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
    index 1929f9d02156e..22873909c33fa 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
    @@ -17,6 +17,7 @@
     
     package org.apache.spark.ml.tree
     
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     
     /**
      * Abstraction for Decision Tree models.
    @@ -70,6 +71,10 @@ private[ml] trait TreeEnsembleModel {
       /** Weights for each tree, zippable with [[trees]] */
       def treeWeights: Array[Double]
     
    +  /** Weights used by the Python wrappers. */
    +  // Note: An array cannot be returned directly due to serialization problems.
    +  private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights)
    +
       /** Summary of the model */
       override def toString: String = {
         // Implementing classes should generally override this method to be more descriptive.
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
    index 816fcedf2efb3..a0c5238d966bf 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
    @@ -17,7 +17,6 @@
     
     package org.apache.spark.ml.tree
     
    -import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.ml.PredictorParams
     import org.apache.spark.ml.param._
     import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
    @@ -26,12 +25,10 @@ import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldG
     import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
     
     /**
    - * :: DeveloperApi ::
      * Parameters for Decision Tree-based algorithms.
      *
      * Note: Marked as private and DeveloperApi since this may be made public in the future.
      */
    -@DeveloperApi
     private[ml] trait DecisionTreeParams extends PredictorParams {
     
       /**
    @@ -265,12 +262,10 @@ private[ml] object TreeRegressorParams {
     }
     
     /**
    - * :: DeveloperApi ::
      * Parameters for Decision Tree-based ensemble algorithms.
      *
      * Note: Marked as private and DeveloperApi since this may be made public in the future.
      */
    -@DeveloperApi
     private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
     
       /**
    @@ -307,12 +302,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
     }
     
     /**
    - * :: DeveloperApi ::
      * Parameters for Random Forest algorithms.
      *
      * Note: Marked as private and DeveloperApi since this may be made public in the future.
      */
    -@DeveloperApi
     private[ml] trait RandomForestParams extends TreeEnsembleParams {
     
       /**
    @@ -377,12 +370,10 @@ private[ml] object RandomForestParams {
     }
     
     /**
    - * :: DeveloperApi ::
      * Parameters for Gradient-Boosted Tree algorithms.
      *
      * Note: Marked as private and DeveloperApi since this may be made public in the future.
      */
    -@DeveloperApi
     private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
     
       /**
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    index ac0d1fed84b2e..e2444ab65b43b 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    @@ -20,9 +20,11 @@ package org.apache.spark.ml.tuning
     import com.github.fommil.netlib.F2jBLAS
     
     import org.apache.spark.Logging
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml._
    +import org.apache.spark.ml.evaluation.Evaluator
     import org.apache.spark.ml.param._
    +import org.apache.spark.ml.util.Identifiable
     import org.apache.spark.mllib.util.MLUtils
     import org.apache.spark.sql.DataFrame
     import org.apache.spark.sql.types.StructType
    @@ -77,11 +79,14 @@ private[ml] trait CrossValidatorParams extends Params {
     }
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * K-fold cross validation.
      */
    -@AlphaComponent
    -class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging {
    +@Experimental
    +class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
    +  with CrossValidatorParams with Logging {
    +
    +  def this() = this(Identifiable.randomUID("cv"))
     
       private val f2jBLAS = new F2jBLAS
     
    @@ -97,12 +102,6 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
       /** @group setParam */
       def setNumFolds(value: Int): this.type = set(numFolds, value)
     
    -  override def validateParams(paramMap: ParamMap): Unit = {
    -    getEstimatorParamMaps.foreach { eMap =>
    -      getEstimator.validateParams(eMap ++ paramMap)
    -    }
    -  }
    -
       override def fit(dataset: DataFrame): CrossValidatorModel = {
         val schema = dataset.schema
         transformSchema(schema, logging = true)
    @@ -136,26 +135,46 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
         logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
         logInfo(s"Best cross-validation metric: $bestMetric.")
         val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
    -    copyValues(new CrossValidatorModel(this, bestModel))
    +    copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
       }
     
       override def transformSchema(schema: StructType): StructType = {
         $(estimator).transformSchema(schema)
       }
    +
    +  override def validateParams(): Unit = {
    +    super.validateParams()
    +    val est = $(estimator)
    +    for (paramMap <- $(estimatorParamMaps)) {
    +      est.copy(paramMap).validateParams()
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): CrossValidator = {
    +    val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
    +    if (copied.isDefined(estimator)) {
    +      copied.setEstimator(copied.getEstimator.copy(extra))
    +    }
    +    if (copied.isDefined(evaluator)) {
    +      copied.setEvaluator(copied.getEvaluator.copy(extra))
    +    }
    +    copied
    +  }
     }
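    +
    +// A minimal usage sketch (illustrative; `lr`, `evaluator` and `training` are assumed to be an
    +// Estimator, an Evaluator and a DataFrame defined elsewhere):
    +//
    +//   val grid = new ParamGridBuilder()
    +//     .addGrid(lr.regParam, Array(0.1, 0.01))
    +//     .build()
    +//   val cv = new CrossValidator()
    +//     .setEstimator(lr)
    +//     .setEvaluator(evaluator)
    +//     .setEstimatorParamMaps(grid)
    +//     .setNumFolds(3)
    +//   val cvModel = cv.fit(training)
    +//   println(cvModel.avgMetrics.mkString(", "))   // one averaged metric per param map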
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Model from k-fold cross validation.
      */
    -@AlphaComponent
    +@Experimental
     class CrossValidatorModel private[ml] (
    -    override val parent: CrossValidator,
    -    val bestModel: Model[_])
    +    override val uid: String,
    +    val bestModel: Model[_],
    +    val avgMetrics: Array[Double])
       extends Model[CrossValidatorModel] with CrossValidatorParams {
     
    -  override def validateParams(paramMap: ParamMap): Unit = {
    -    bestModel.validateParams(paramMap)
    +  override def validateParams(): Unit = {
    +    bestModel.validateParams()
       }
     
       override def transform(dataset: DataFrame): DataFrame = {
    @@ -166,4 +185,12 @@ class CrossValidatorModel private[ml] (
       override def transformSchema(schema: StructType): StructType = {
         bestModel.transformSchema(schema)
       }
    +
    +  override def copy(extra: ParamMap): CrossValidatorModel = {
    +    val copied = new CrossValidatorModel(
    +      uid,
    +      bestModel.copy(extra).asInstanceOf[Model[_]],
    +      avgMetrics.clone())
    +    copyValues(copied, extra)
    +  }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
    index dafe73d82c00a..98a8f0330ca45 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
    @@ -20,14 +20,14 @@ package org.apache.spark.ml.tuning
     import scala.annotation.varargs
     import scala.collection.mutable
     
    -import org.apache.spark.annotation.AlphaComponent
    +import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.param._
     
     /**
    - * :: AlphaComponent ::
    + * :: Experimental ::
      * Builder for a param grid used in grid search-based model selection.
      */
    -@AlphaComponent
    +@Experimental
     class ParamGridBuilder {
     
       private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]]
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
    index 8a56748ab0a02..ddd34a54503a6 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
    @@ -19,15 +19,26 @@ package org.apache.spark.ml.util
     
     import java.util.UUID
     
    +
     /**
    - * Object with a unique id.
    + * Trait for an object with an immutable unique ID that identifies itself and its derivatives.
      */
    -private[ml] trait Identifiable extends Serializable {
    +private[spark] trait Identifiable {
    +
    +  /**
    +   * An immutable unique ID for the object and its derivatives.
    +   */
    +  val uid: String
    +
    +  override def toString: String = uid
    +}
    +
    +private[spark] object Identifiable {
     
       /**
    -   * A unique id for the object. The default implementation concatenates the class name, "_", and 8
    -   * random hex chars.
    +   * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars.
        */
    -  private[ml] val uid: String =
    -    this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8)
    +  def randomUID(prefix: String): String = {
    +    prefix + "_" + UUID.randomUUID().toString.takeRight(12)
    +  }
     }
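    +
    +// For example (the hex suffix below is illustrative; it is random in practice):
    +//
    +//   Identifiable.randomUID("als")   // e.g. "als_1a2b3c4d5e6f"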
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
    index 56075c9a6b39f..2a1db90f2ca2b 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
    @@ -19,18 +19,14 @@ package org.apache.spark.ml.util
     
     import scala.collection.immutable.HashMap
     
    -import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.attribute._
     import org.apache.spark.sql.types.StructField
     
     
     /**
    - * :: Experimental ::
    - *
      * Helper utilities for tree-based algorithms
      */
    -@Experimental
    -object MetadataUtils {
    +private[spark] object MetadataUtils {
     
       /**
        * Examine a schema to identify the number of classes in a label column.
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
    index 11592b77eb356..76f651488aef9 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
    @@ -17,15 +17,13 @@
     
     package org.apache.spark.ml.util
     
    -import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.sql.types.{DataType, StructField, StructType}
     
    +
     /**
    - * :: DeveloperApi ::
      * Utils for handling schemas.
      */
    -@DeveloperApi
    -object SchemaUtils {
    +private[spark] object SchemaUtils {
     
       // TODO: Move the utility methods to SQL.
     
    @@ -34,10 +32,15 @@ object SchemaUtils {
        * @param colName  column name
        * @param dataType  required column data type
        */
    -  def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
    +  def checkColumnType(
    +      schema: StructType,
    +      colName: String,
    +      dataType: DataType,
    +      msg: String = ""): Unit = {
         val actualDataType = schema(colName).dataType
    +    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
         require(actualDataType.equals(dataType),
    -      s"Column $colName must be of type $dataType but was actually $actualDataType.")
    +      s"Column $colName must be of type $dataType but was actually $actualDataType.$message")
       }
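    +
    +  // For example (hypothetical schema): if schema("user").dataType is StringType, then
    +  //   checkColumnType(schema, "user", IntegerType)
    +  // fails with "Column user must be of type IntegerType but was actually StringType."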
     
       /**
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala
    new file mode 100644
    index 0000000000000..bc6041b221732
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala
    @@ -0,0 +1,32 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.api.python
    +
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel
    +
    +/**
    + * A wrapper around PowerIterationClusteringModel to provide helper methods for Python.
    + */
    +private[python] class PowerIterationClusteringModelWrapper(model: PowerIterationClusteringModel)
    +  extends PowerIterationClusteringModel(model.k, model.assignments) {
    +
    +  def getAssignments: RDD[Array[Any]] = {
    +    model.assignments.map(x => Array(x.id, x.cluster))
    +  }
    +}
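    +
    +// A minimal sketch (illustrative; `picModel` is an assumed PowerIterationClusteringModel):
    +//
    +//   val wrapper = new PowerIterationClusteringModelWrapper(picModel)
    +//   val idsAndClusters = wrapper.getAssignments   // RDD of Array(id, cluster) pairs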
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    index f4c477596557f..e628059c4af8e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    @@ -28,6 +28,7 @@ import scala.reflect.ClassTag
     
     import net.razorvine.pickle._
     
    +import org.apache.spark.SparkContext
     import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
     import org.apache.spark.api.python.SerDeUtil
     import org.apache.spark.mllib.classification._
    @@ -43,13 +44,15 @@ import org.apache.spark.mllib.regression._
     import org.apache.spark.mllib.stat.correlation.CorrelationNames
     import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
     import org.apache.spark.mllib.stat.test.ChiSqTestResult
    -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
    +import org.apache.spark.mllib.stat.{KernelDensity, MultivariateStatisticalSummary, Statistics}
     import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
     import org.apache.spark.mllib.tree.impurity._
     import org.apache.spark.mllib.tree.loss.Losses
     import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
     import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
     import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.mllib.util.LinearDataGenerator
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     import org.apache.spark.storage.StorageLevel
    @@ -73,6 +76,15 @@ private[python] class PythonMLLibAPI extends Serializable {
           minPartitions: Int): JavaRDD[LabeledPoint] =
         MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
     
    +  /**
    +   * Loads and serializes vectors saved with `RDD#saveAsTextFile`.
    +   * @param jsc Java SparkContext
    +   * @param path file or directory path in any Hadoop-supported file system URI
     +   * @return serialized vectors in an RDD
    +   */
    +  def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] =
    +    MLUtils.loadVectors(jsc.sc, path)
    +
       private def trainRegressionModel(
           learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
           data: JavaRDD[LabeledPoint],
    @@ -276,7 +288,7 @@ private[python] class PythonMLLibAPI extends Serializable {
       /**
        * Java stub for NaiveBayes.train()
        */
    -  def trainNaiveBayes(
    +  def trainNaiveBayesModel(
           data: JavaRDD[LabeledPoint],
           lambda: Double): JList[Object] = {
         val model = NaiveBayes.train(data.rdd, lambda)
    @@ -344,29 +356,41 @@ private[python] class PythonMLLibAPI extends Serializable {
        * Java stub for Python mllib GaussianMixture.run()
        * Returns a list containing weights, mean and covariance of each mixture component.
        */
    -  def trainGaussianMixture(
    -      data: JavaRDD[Vector], 
    -      k: Int, 
    -      convergenceTol: Double, 
    +  def trainGaussianMixtureModel(
    +      data: JavaRDD[Vector],
    +      k: Int,
    +      convergenceTol: Double,
           maxIterations: Int,
    -      seed: java.lang.Long): JList[Object] = {
    +      seed: java.lang.Long,
    +      initialModelWeights: java.util.ArrayList[Double],
    +      initialModelMu: java.util.ArrayList[Vector],
    +      initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
         val gmmAlg = new GaussianMixture()
           .setK(k)
           .setConvergenceTol(convergenceTol)
           .setMaxIterations(maxIterations)
     
    +    if (initialModelWeights != null && initialModelMu != null && initialModelSigma != null) {
    +      val gaussians = initialModelMu.asScala.toSeq.zip(initialModelSigma.asScala.toSeq).map {
    +        case (x, y) => new MultivariateGaussian(x.asInstanceOf[Vector], y.asInstanceOf[Matrix])
    +      }
    +      val initialModel = new GaussianMixtureModel(
    +        initialModelWeights.asScala.toArray, gaussians.toArray)
    +      gmmAlg.setInitialModel(initialModel)
    +    }
    +
         if (seed != null) gmmAlg.setSeed(seed)
     
         try {
           val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
           var wt = ArrayBuffer.empty[Double]
    -      var mu = ArrayBuffer.empty[Vector]      
    +      var mu = ArrayBuffer.empty[Vector]
           var sigma = ArrayBuffer.empty[Matrix]
           for (i <- 0 until model.k) {
               wt += model.weights(i)
               mu += model.gaussians(i).mu
               sigma += model.gaussians(i).sigma
    -      }    
    +      }
           List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
         } finally {
           data.rdd.unpersist(blocking = false)
    @@ -380,18 +404,45 @@ private[python] class PythonMLLibAPI extends Serializable {
           data: JavaRDD[Vector],
           wt: Vector,
           mu: Array[Object],
    -      si: Array[Object]):  RDD[Vector]  = {
    +      si: Array[Object]): RDD[Vector] = {
     
           val weight = wt.toArray
           val mean = mu.map(_.asInstanceOf[DenseVector])
           val sigma = si.map(_.asInstanceOf[DenseMatrix])
           val gaussians = Array.tabulate(weight.length){
             i => new MultivariateGaussian(mean(i), sigma(i))
    -      }      
    +      }
           val model = new GaussianMixtureModel(weight, gaussians)
           model.predictSoft(data).map(Vectors.dense)
       }
     
    +  /**
    +   * Java stub for Python mllib PowerIterationClustering.run(). This stub returns a
    +   * handle to the Java object instead of the content of the Java object.  Extra care
    +   * needs to be taken in the Python code to ensure it gets freed on exit; see the
    +   * Py4J documentation.
    +   * @param data an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix.
    +   * @param k number of clusters.
    +   * @param maxIterations maximum number of iterations of the power iteration loop.
    +   * @param initMode the initialization mode. This can be either "random" to use
    +   *                 a random vector as vertex properties, or "degree" to use
    +   *                 normalized sum similarities. Default: random.
    +   */
    +  def trainPowerIterationClusteringModel(
    +      data: JavaRDD[Vector],
    +      k: Int,
    +      maxIterations: Int,
    +      initMode: String): PowerIterationClusteringModel = {
    +
    +    val pic = new PowerIterationClustering()
    +      .setK(k)
    +      .setMaxIterations(maxIterations)
    +      .setInitializationMode(initMode)
    +
    +    val model = pic.run(data.rdd.map(v => (v(0).toLong, v(1).toLong, v(2))))
    +    new PowerIterationClusteringModelWrapper(model)
    +  }
    +
       /**
        * Java stub for Python mllib ALS.train().  This stub returns a handle
        * to the Java object instead of the content of the Java object.  Extra care
    @@ -416,7 +467,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     
         if (seed != null) als.setSeed(seed)
     
    -    val model =  als.run(ratingsJRDD.rdd)
    +    val model = als.run(ratingsJRDD.rdd)
         new MatrixFactorizationModelWrapper(model)
       }
     
    @@ -447,7 +498,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     
         if (seed != null) als.setSeed(seed)
     
    -    val model =  als.run(ratingsJRDD.rdd)
    +    val model = als.run(ratingsJRDD.rdd)
         new MatrixFactorizationModelWrapper(model)
       }
     
    @@ -482,7 +533,7 @@ private[python] class PythonMLLibAPI extends Serializable {
       def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
         new Normalizer(p).transform(rdd)
       }
    -  
    +
       /**
        * Java stub for StandardScaler.fit(). This stub returns a
        * handle to the Java object instead of the content of the Java object.
    @@ -506,6 +557,16 @@ private[python] class PythonMLLibAPI extends Serializable {
         new ChiSqSelector(numTopFeatures).fit(data.rdd)
       }
     
    +  /**
    +   * Java stub for PCA.fit(). This stub returns a
    +   * handle to the Java object instead of the content of the Java object.
    +   * Extra care needs to be taken in the Python code to ensure it gets freed on
    +   * exit; see the Py4J documentation.
    +   */
    +  def fitPCA(k: Int, data: JavaRDD[Vector]): PCAModel = {
    +    new PCA(k).fit(data.rdd)
    +  }
    +
       /**
        * Java stub for IDF.fit(). This stub returns a
        * handle to the Java object instead of the content of the Java object.
    @@ -529,7 +590,7 @@ private[python] class PythonMLLibAPI extends Serializable {
        * @param seed initial seed for random generator
        * @return A handle to java Word2VecModelWrapper instance at python side
        */
    -  def trainWord2Vec(
    +  def trainWord2VecModel(
           dataJRDD: JavaRDD[java.util.ArrayList[String]],
           vectorSize: Int,
           learningRate: Double,
    @@ -581,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable {
         def getVectors: JMap[String, JList[Float]] = {
           model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
         }
    +
    +    def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
       }
     
       /**
    @@ -673,12 +736,14 @@ private[python] class PythonMLLibAPI extends Serializable {
           lossStr: String,
           numIterations: Int,
           learningRate: Double,
    -      maxDepth: Int): GradientBoostedTreesModel = {
    +      maxDepth: Int,
    +      maxBins: Int): GradientBoostedTreesModel = {
         val boostingStrategy = BoostingStrategy.defaultParams(algoStr)
         boostingStrategy.setLoss(Losses.fromString(lossStr))
         boostingStrategy.setNumIterations(numIterations)
         boostingStrategy.setLearningRate(learningRate)
         boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
    +    boostingStrategy.treeStrategy.setMaxBins(maxBins)
         boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap
     
         val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
    @@ -689,6 +754,14 @@ private[python] class PythonMLLibAPI extends Serializable {
         }
       }
     
    +  def elementwiseProductVector(scalingVector: Vector, vector: Vector): Vector = {
    +    new ElementwiseProduct(scalingVector).transform(vector)
    +  }
    +
    +  def elementwiseProductVector(scalingVector: Vector, vector: JavaRDD[Vector]): JavaRDD[Vector] = {
    +    new ElementwiseProduct(scalingVector).transform(vector)
    +  }
    +
       /**
        * Java stub for mllib Statistics.colStats(X: RDD[Vector]).
        * TODO figure out return type.
    @@ -933,7 +1006,60 @@ private[python] class PythonMLLibAPI extends Serializable {
           r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
       }
     
    +  /**
    +   * Java stub for the estimate method of KernelDensity
    +   */
    +  def estimateKernelDensity(
    +      sample: JavaRDD[Double],
    +      bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
    +    new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
    +      points.asScala.toArray)
    +  }
    +
    +  /**
    +   * Java stub for the update method of StreamingKMeansModel.
    +   */
    +  def updateStreamingKMeansModel(
    +      clusterCenters: JList[Vector],
    +      clusterWeights: JList[Double],
    +      data: JavaRDD[Vector],
    +      decayFactor: Double,
    +      timeUnit: String): JList[Object] = {
    +    val model = new StreamingKMeansModel(
    +      clusterCenters.asScala.toArray, clusterWeights.asScala.toArray)
    +        .update(data, decayFactor, timeUnit)
     +    List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava
    +  }
     
    +  /**
    +   * Wrapper around the generateLinearInput method of LinearDataGenerator.
    +   */
    +  def generateLinearInputWrapper(
    +      intercept: Double,
    +      weights: JList[Double],
    +      xMean: JList[Double],
    +      xVariance: JList[Double],
    +      nPoints: Int,
    +      seed: Int,
    +      eps: Double): Array[LabeledPoint] = {
    +    LinearDataGenerator.generateLinearInput(
    +      intercept, weights.asScala.toArray, xMean.asScala.toArray,
    +      xVariance.asScala.toArray, nPoints, seed, eps).toArray
    +  }
    +
    +  /**
    +   * Wrapper around the generateLinearRDD method of LinearDataGenerator.
    +   */
    +  def generateLinearRDDWrapper(
    +      sc: JavaSparkContext,
    +      nexamples: Int,
    +      nfeatures: Int,
    +      eps: Double,
    +      nparts: Int,
    +      intercept: Double): JavaRDD[LabeledPoint] = {
    +    LinearDataGenerator.generateLinearRDD(
    +      sc, nexamples, nfeatures, eps, nparts, intercept)
    +  }
     }
     
     /**
    @@ -1230,7 +1356,7 @@ private[spark] object SerDe extends Serializable {
       }
     
       /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
    -  def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]]  = {
    +  def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
         rdd.map(x => Array(x._1, x._2))
       }
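
The extra `initialModelWeights`/`initialModelMu`/`initialModelSigma` parameters added to `trainGaussianMixtureModel` above let the Python API seed EM with a user-supplied GMM. Here is a hedged sketch of the equivalent call through the public Scala API; it assumes a spark-shell session with `sc` in scope and Spark 1.4-era MLlib, and the data and starting parameters are made up.

```scala
import org.apache.spark.mllib.clustering.{GaussianMixture, GaussianMixtureModel}
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

// Toy 2-D data with two obvious clusters (illustrative values only).
val data = sc.parallelize(Seq(
  Vectors.dense(0.1, 0.2), Vectors.dense(0.0, 0.3),
  Vectors.dense(5.0, 5.1), Vectors.dense(4.9, 5.0)))

// A starting GMM: equal weights, unit covariances, means near each cluster.
val eye = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
val initialModel = new GaussianMixtureModel(
  Array(0.5, 0.5),
  Array(
    new MultivariateGaussian(Vectors.dense(0.0, 0.0), eye),
    new MultivariateGaussian(Vectors.dense(5.0, 5.0), eye)))

val model = new GaussianMixture()
  .setK(2)
  .setInitialModel(initialModel) // same role as the initialModel* arguments above
  .run(data)

println(model.weights.mkString(", "))
```
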
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    index bd2e9079ce1ae..2df4d21e8cd55 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    @@ -163,7 +163,7 @@ class LogisticRegressionModel (
       override protected def formatVersion: String = "1.0"
     
       override def toString: String = {
    -    s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.get}"
    +    s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}"
       }
     }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    index c9b3ff0172e2e..9e379d7d74b2f 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    @@ -21,21 +21,16 @@ import java.lang.{Iterable => JIterable}
     
     import scala.collection.JavaConverters._
     
    -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
    -import breeze.numerics.{exp => brzExp, log => brzLog}
    -
     import org.json4s.JsonDSL._
     import org.json4s.jackson.JsonMethods._
    -import org.json4s.{DefaultFormats, JValue}
     
     import org.apache.spark.{Logging, SparkContext, SparkException}
    -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
    +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.{Loader, Saveable}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.{DataFrame, SQLContext}
     
    -
     /**
      * Model for Naive Bayes Classifiers.
      *
    @@ -43,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
      * @param pi log of class priors, whose dimension is C, number of labels
      * @param theta log of class conditional probabilities, whose dimension is C-by-D,
      *              where D is number of features
    - * @param modelType The type of NB model to fit  can be "Multinomial" or "Bernoulli"
     + * @param modelType The type of NB model to fit; can be "multinomial" or "bernoulli"
      */
     class NaiveBayesModel private[mllib] (
         val labels: Array[Double],
    @@ -52,8 +47,13 @@ class NaiveBayesModel private[mllib] (
         val modelType: String)
       extends ClassificationModel with Serializable with Saveable {
     
    +  import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
    +
    +  private val piVector = new DenseVector(pi)
    +  private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true)
    +
       private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
    -    this(labels, pi, theta, "Multinomial")
    +    this(labels, pi, theta, NaiveBayes.Multinomial)
     
       /** A Java-friendly constructor that takes three Iterable parameters. */
       private[mllib] def this(
    @@ -62,20 +62,24 @@ class NaiveBayesModel private[mllib] (
           theta: JIterable[JIterable[Double]]) =
         this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
     
    -  private val brzPi = new BDV[Double](pi)
    -  private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
    +  require(supportedModelTypes.contains(modelType),
    +    s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.")
     
       // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
    -  // This precomputes log(1.0 - exp(theta)) and its sum  which are used for the  linear algebra
     +  // This precomputes log(1.0 - exp(theta)) and its sum, which are used for the linear algebra
       // application of this condition (in predict function).
    -  private val (brzNegTheta, brzNegThetaSum) = modelType match {
    -    case "Multinomial" => (None, None)
    -    case "Bernoulli" =>
    -      val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
    -      (Option(negTheta), Option(brzSum(negTheta, Axis._1)))
    +  private val (thetaMinusNegTheta, negThetaSum) = modelType match {
    +    case Multinomial => (None, None)
    +    case Bernoulli =>
    +      val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
    +      val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
    +      val thetaMinusNegTheta = thetaMatrix.map { value =>
    +        value - math.log(1.0 - math.exp(value))
    +      }
    +      (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
         case _ =>
           // This should never happen.
    -      throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
    +      throw new UnknownError(s"Invalid modelType: $modelType.")
       }
     
       override def predict(testData: RDD[Vector]): RDD[Double] = {
    @@ -88,17 +92,71 @@ class NaiveBayesModel private[mllib] (
     
       override def predict(testData: Vector): Double = {
         modelType match {
    -      case "Multinomial" =>
    -        labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
    -      case "Bernoulli" =>
    -        labels (brzArgmax (brzPi +
    -          (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
    -      case _ =>
    -        // This should never happen.
    -        throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
    +      case Multinomial =>
    +        labels(multinomialCalculation(testData).argmax)
    +      case Bernoulli =>
    +        labels(bernoulliCalculation(testData).argmax)
    +    }
    +  }
    +
    +  /**
    +   * Predict values for the given data set using the model trained.
    +   *
    +   * @param testData RDD representing data points to be predicted
    +   * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
    +   *         in the same order as class labels
    +   */
    +  def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
    +    val bcModel = testData.context.broadcast(this)
    +    testData.mapPartitions { iter =>
    +      val model = bcModel.value
    +      iter.map(model.predictProbabilities)
    +    }
    +  }
    +
    +  /**
    +   * Predict posterior class probabilities for a single data point using the model trained.
    +   *
    +   * @param testData array representing a single data point
    +   * @return predicted posterior class probabilities from the trained model,
    +   *         in the same order as class labels
    +   */
    +  def predictProbabilities(testData: Vector): Vector = {
    +    modelType match {
    +      case Multinomial =>
    +        posteriorProbabilities(multinomialCalculation(testData))
    +      case Bernoulli =>
    +        posteriorProbabilities(bernoulliCalculation(testData))
         }
       }
     
    +  private def multinomialCalculation(testData: Vector) = {
    +    val prob = thetaMatrix.multiply(testData)
    +    BLAS.axpy(1.0, piVector, prob)
    +    prob
    +  }
    +
    +  private def bernoulliCalculation(testData: Vector) = {
    +    testData.foreachActive((_, value) =>
    +      if (value != 0.0 && value != 1.0) {
    +        throw new SparkException(
    +          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
    +      }
    +    )
    +    val prob = thetaMinusNegTheta.get.multiply(testData)
    +    BLAS.axpy(1.0, piVector, prob)
    +    BLAS.axpy(1.0, negThetaSum.get, prob)
    +    prob
    +  }
    +
    +  private def posteriorProbabilities(logProb: DenseVector) = {
    +    val logProbArray = logProb.toArray
    +    val maxLog = logProbArray.max
    +    val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog))
    +    val probSum = scaledProbs.sum
    +    new DenseVector(scaledProbs.map(_ / probSum))
    +  }
    +
       override def save(sc: SparkContext, path: String): Unit = {
         val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
         NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
    @@ -137,17 +195,17 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     
           // Create Parquet data.
           val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
    -      dataRDD.saveAsParquetFile(dataPath(path))
    +      dataRDD.write.parquet(dataPath(path))
         }
     
         def load(sc: SparkContext, path: String): NaiveBayesModel = {
           val sqlContext = new SQLContext(sc)
           // Load Parquet data.
    -      val dataRDD = sqlContext.parquetFile(dataPath(path))
    +      val dataRDD = sqlContext.read.parquet(dataPath(path))
           // Check schema explicitly since erasure makes it hard to use match-case for checking.
           checkSchema[Data](dataRDD.schema)
           val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
    -      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
    +      assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
           val data = dataArray(0)
           val labels = data.getAs[Seq[Double]](0).toArray
           val pi = data.getAs[Seq[Double]](1).toArray
    @@ -183,17 +241,17 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     
           // Create Parquet data.
           val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
    -      dataRDD.saveAsParquetFile(dataPath(path))
    +      dataRDD.write.parquet(dataPath(path))
         }
     
         def load(sc: SparkContext, path: String): NaiveBayesModel = {
           val sqlContext = new SQLContext(sc)
           // Load Parquet data.
    -      val dataRDD = sqlContext.parquetFile(dataPath(path))
    +      val dataRDD = sqlContext.read.parquet(dataPath(path))
           // Check schema explicitly since erasure makes it hard to use match-case for checking.
           checkSchema[Data](dataRDD.schema)
           val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
    -      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
    +      assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
           val data = dataArray(0)
           val labels = data.getAs[Seq[Double]](0).toArray
           val pi = data.getAs[Seq[Double]](1).toArray
    @@ -220,16 +278,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
             s"($loadedClassName, $version).  Supported:\n" +
             s"  ($classNameV1_0, 1.0)")
         }
    -    assert(model.pi.size == numClasses,
    +    assert(model.pi.length == numClasses,
           s"NaiveBayesModel.load expected $numClasses classes," +
    -        s" but class priors vector pi had ${model.pi.size} elements")
    -    assert(model.theta.size == numClasses,
    +        s" but class priors vector pi had ${model.pi.length} elements")
    +    assert(model.theta.length == numClasses,
           s"NaiveBayesModel.load expected $numClasses classes," +
    -        s" but class conditionals array theta had ${model.theta.size} elements")
    -    assert(model.theta.forall(_.size == numFeatures),
    +        s" but class conditionals array theta had ${model.theta.length} elements")
    +    assert(model.theta.forall(_.length == numFeatures),
           s"NaiveBayesModel.load expected $numFeatures features," +
             s" but class conditionals array theta had elements of size:" +
    -        s" ${model.theta.map(_.size).mkString(",")}")
    +        s" ${model.theta.map(_.length).mkString(",")}")
         model
       }
     }
    @@ -247,9 +305,11 @@ class NaiveBayes private (
         private var lambda: Double,
         private var modelType: String) extends Serializable with Logging {
     
    -  def this(lambda: Double) = this(lambda, "Multinomial")
    +  import NaiveBayes.{Bernoulli, Multinomial}
    +
    +  def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
     
    -  def this() = this(1.0, "Multinomial")
    +  def this() = this(1.0, NaiveBayes.Multinomial)
     
       /** Set the smoothing parameter. Default: 1.0. */
       def setLambda(lambda: Double): NaiveBayes = {
    @@ -262,12 +322,11 @@ class NaiveBayes private (
     
       /**
        * Set the model type using a string (case-sensitive).
    -   * Supported options: "Multinomial" and "Bernoulli".
    -   * (default: Multinomial)
    +   * Supported options: "multinomial" (default) and "bernoulli".
        */
    -  def setModelType(modelType:String): NaiveBayes = {
    +  def setModelType(modelType: String): NaiveBayes = {
         require(NaiveBayes.supportedModelTypes.contains(modelType),
    -      s"NaiveBayes was created with an unknown ModelType: $modelType")
    +      s"NaiveBayes was created with an unknown modelType: $modelType.")
         this.modelType = modelType
         this
       }
    @@ -283,30 +342,46 @@ class NaiveBayes private (
       def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
         val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
           val values = v match {
    -        case SparseVector(size, indices, values) =>
    -          values
    -        case DenseVector(values) =>
    -          values
    +        case sv: SparseVector => sv.values
    +        case dv: DenseVector => dv.values
           }
           if (!values.forall(_ >= 0.0)) {
             throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
           }
         }
     
    +    val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
    +      val values = v match {
    +        case sv: SparseVector => sv.values
    +        case dv: DenseVector => dv.values
    +      }
    +      if (!values.forall(v => v == 0.0 || v == 1.0)) {
    +        throw new SparkException(
    +          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
    +      }
    +    }
    +
         // Aggregates term frequencies per label.
         // TODO: Calling combineByKey and collect creates two stages, we can implement something
         // TODO: similar to reduceByKeyLocally to save one stage.
    -    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
    +    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
           createCombiner = (v: Vector) => {
    -        requireNonnegativeValues(v)
    -        (1L, v.toBreeze.toDenseVector)
    +        if (modelType == Bernoulli) {
    +          requireZeroOneBernoulliValues(v)
    +        } else {
    +          requireNonnegativeValues(v)
    +        }
    +        (1L, v.copy.toDense)
           },
    -      mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
    +      mergeValue = (c: (Long, DenseVector), v: Vector) => {
             requireNonnegativeValues(v)
    -        (c._1 + 1L, c._2 += v.toBreeze)
    +        BLAS.axpy(1.0, v, c._2)
    +        (c._1 + 1L, c._2)
           },
    -      mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
    -        (c1._1 + c2._1, c1._2 += c2._2)
    +      mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
    +        BLAS.axpy(1.0, c2._2, c1._2)
    +        (c1._1 + c2._1, c1._2)
    +      }
         ).collect()
     
         val numLabels = aggregated.length
    @@ -326,11 +401,11 @@ class NaiveBayes private (
           labels(i) = label
           pi(i) = math.log(n + lambda) - piLogDenom
           val thetaLogDenom = modelType match {
    -        case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
    -        case "Bernoulli" => math.log(n + 2.0 * lambda)
    +        case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
    +        case Bernoulli => math.log(n + 2.0 * lambda)
             case _ =>
               // This should never happen.
    -          throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
    +          throw new UnknownError(s"Invalid modelType: $modelType.")
           }
           var j = 0
           while (j < numFeatures) {
    @@ -349,8 +424,14 @@ class NaiveBayes private (
      */
     object NaiveBayes {
     
    +  /** String name for multinomial model type. */
    +  private[classification] val Multinomial: String = "multinomial"
    +
    +  /** String name for Bernoulli model type. */
    +  private[classification] val Bernoulli: String = "bernoulli"
    +
       /* Set of modelTypes that NaiveBayes supports */
    -  private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
    +  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
     
       /**
        * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
    @@ -380,7 +461,7 @@ object NaiveBayes {
        * @param lambda The smoothing parameter
        */
       def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
    -    new NaiveBayes(lambda, "Multinomial").run(input)
    +    new NaiveBayes(lambda, Multinomial).run(input)
       }
     
       /**
    @@ -403,7 +484,7 @@ object NaiveBayes {
        */
       def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
         require(supportedModelTypes.contains(modelType),
    -      s"NaiveBayes was created with an unknown ModelType: $modelType")
    +      s"NaiveBayes was created with an unknown modelType: $modelType.")
         new NaiveBayes(lambda, modelType).run(input)
       }
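
`predictProbabilities` above turns the per-class log scores into posteriors by subtracting the maximum log value before exponentiating, which keeps `exp` from underflowing. A self-contained sketch of that normalization step follows; the `PosteriorSketch` object is illustrative, not MLlib code.

```scala
object PosteriorSketch {
  // Same arithmetic as NaiveBayesModel.posteriorProbabilities: shift by the max
  // log-probability, exponentiate, then renormalize so the result sums to 1.
  def posteriors(logProb: Array[Double]): Array[Double] = {
    val maxLog = logProb.max
    val scaled = logProb.map(lp => math.exp(lp - maxLog))
    val total = scaled.sum
    scaled.map(_ / total)
  }

  def main(args: Array[String]): Unit = {
    // Log scores of -1000 and -1001 would underflow a naive exp(); here they
    // come out as roughly (0.731, 0.269).
    println(posteriors(Array(-1000.0, -1001.0)).mkString(", "))
  }
}
```
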
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    index 33104cf06c6ea..348485560713e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    @@ -89,7 +89,7 @@ class SVMModel (
       override protected def formatVersion: String = "1.0"
     
       override def toString: String = {
    -    s"${super.toString}, numClasses = 2, threshold = ${threshold.get}"
    +    s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}"
       }
     }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
    index 3b6790cce47c6..fe09f6b75d28b 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
    @@ -62,7 +62,7 @@ private[classification] object GLMClassificationModel {
     
           // Create Parquet data.
           val data = Data(weights, intercept, threshold)
    -      sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path))
    +      sc.parallelize(Seq(data), 1).toDF().write.parquet(Loader.dataPath(path))
         }
     
         /**
    @@ -75,7 +75,7 @@ private[classification] object GLMClassificationModel {
         def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
           val datapath = Loader.dataPath(path)
           val sqlContext = new SQLContext(sc)
    -      val dataRDD = sqlContext.parquetFile(datapath)
    +      val dataRDD = sqlContext.read.parquet(datapath)
           val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
           assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
           val data = dataArray(0)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
    index c88410ac0ff43..e459367333d26 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
    @@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq
     import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
     
     import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
     import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
     import org.apache.spark.mllib.util.MLUtils
    @@ -36,11 +37,11 @@ import org.apache.spark.util.Utils
      * independent Gaussian distributions with associated "mixing" weights
      * specifying each's contribution to the composite.
      *
    - * Given a set of sample points, this class will maximize the log-likelihood 
    - * for a mixture of k Gaussians, iterating until the log-likelihood changes by 
    + * Given a set of sample points, this class will maximize the log-likelihood
    + * for a mixture of k Gaussians, iterating until the log-likelihood changes by
      * less than convergenceTol, or until it has reached the max number of iterations.
      * While this process is generally guaranteed to converge, it is not guaranteed
    - * to find a global optimum.  
    + * to find a global optimum.
      *
      * Note: For high-dimensional data (with many features), this algorithm may perform poorly.
      *       This is due to high-dimensional data (a) making it difficult to cluster at all (based
    @@ -53,24 +54,24 @@ import org.apache.spark.util.Utils
      */
     @Experimental
     class GaussianMixture private (
    -    private var k: Int, 
    -    private var convergenceTol: Double, 
    +    private var k: Int,
    +    private var convergenceTol: Double,
         private var maxIterations: Int,
         private var seed: Long) extends Serializable {
    -  
    +
       /**
        * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01,
        * maxIterations: 100, seed: random}.
        */
       def this() = this(2, 0.01, 100, Utils.random.nextLong())
    -  
    +
       // number of samples per cluster to use when initializing Gaussians
       private val nSamples = 5
    -  
    -  // an initializing GMM can be provided rather than using the 
    +
    +  // an initializing GMM can be provided rather than using the
       // default random starting point
       private var initialModel: Option[GaussianMixtureModel] = None
    -  
    +
       /** Set the initial GMM starting point, bypassing the random initialization.
        *  You must call setK() prior to calling this method, and the condition
        *  (model.k == this.k) must be met; failure will result in an IllegalArgumentException
    @@ -83,37 +84,37 @@ class GaussianMixture private (
         }
         this
       }
    -  
    +
       /** Return the user supplied initial GMM, if supplied */
       def getInitialModel: Option[GaussianMixtureModel] = initialModel
    -  
    +
       /** Set the number of Gaussians in the mixture model.  Default: 2 */
       def setK(k: Int): this.type = {
         this.k = k
         this
       }
    -  
    +
       /** Return the number of Gaussians in the mixture model */
       def getK: Int = k
    -  
    +
       /** Set the maximum number of iterations to run. Default: 100 */
       def setMaxIterations(maxIterations: Int): this.type = {
         this.maxIterations = maxIterations
         this
       }
    -  
    +
       /** Return the maximum number of iterations to run */
       def getMaxIterations: Int = maxIterations
    -  
    +
       /**
    -   * Set the largest change in log-likelihood at which convergence is 
    +   * Set the largest change in log-likelihood at which convergence is
        * considered to have occurred.
        */
       def setConvergenceTol(convergenceTol: Double): this.type = {
         this.convergenceTol = convergenceTol
         this
       }
    -  
    +
       /**
        * Return the largest change in log-likelihood at which convergence is
        * considered to have occurred.
    @@ -132,69 +133,100 @@ class GaussianMixture private (
       /** Perform expectation maximization */
       def run(data: RDD[Vector]): GaussianMixtureModel = {
         val sc = data.sparkContext
    -    
    +
         // we will operate on the data as breeze data
         val breezeData = data.map(_.toBreeze).cache()
    -    
    +
         // Get length of the input vectors
         val d = breezeData.first().length
    -    
    +
     +    // Heuristic to distribute the computation of the [[MultivariateGaussian]]s: distribute
     +    // approximately when d > 25, except when k is very small
    +    val distributeGaussians = ((k - 1.0) / k) * d > 25
    +
         // Determine initial weights and corresponding Gaussians.
         // If the user supplied an initial GMM, we use those values, otherwise
         // we start with uniform weights, a random mean from the data, and
         // diagonal covariance matrices using component variances
    -    // derived from the samples    
    +    // derived from the samples
         val (weights, gaussians) = initialModel match {
           case Some(gmm) => (gmm.weights, gmm.gaussians)
    -      
    +
           case None => {
             val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
    -        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => 
    +        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
               val slice = samples.view(i * nSamples, (i + 1) * nSamples)
    -          new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) 
    +          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
             })
           }
         }
    -    
    -    var llh = Double.MinValue // current log-likelihood 
    +
    +    var llh = Double.MinValue // current log-likelihood
         var llhp = 0.0            // previous log-likelihood
    -    
    +
         var iter = 0
         while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) {
           // create and broadcast curried cluster contribution function
           val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
    -      
    +
           // aggregate the cluster contribution for all sample points
           val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
    -      
    +
           // Create new distributions based on the partial assignments
           // (often referred to as the "M" step in literature)
           val sumWeights = sums.weights.sum
    -      var i = 0
    -      while (i < k) {
    -        val mu = sums.means(i) / sums.weights(i)
    -        BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
    -          Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
    -        weights(i) = sums.weights(i) / sumWeights
    -        gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
    -        i = i + 1
    +
    +      if (distributeGaussians) {
    +        val numPartitions = math.min(k, 1024)
    +        val tuples =
    +          Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i)))
    +        val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) =>
    +          updateWeightsAndGaussians(mean, sigma, weight, sumWeights)
    +        }.collect.unzip
    +        Array.copy(ws, 0, weights, 0, ws.length)
    +        Array.copy(gs, 0, gaussians, 0, gs.length)
    +      } else {
    +        var i = 0
    +        while (i < k) {
    +          val (weight, gaussian) =
    +            updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights)
    +          weights(i) = weight
    +          gaussians(i) = gaussian
    +          i = i + 1
    +        }
           }
    -   
    +
           llhp = llh // current becomes previous
           llh = sums.logLikelihood // this is the freshly computed log-likelihood
           iter += 1
    -    } 
    -    
    +    }
    +
         new GaussianMixtureModel(weights, gaussians)
       }
    -    
    +
    +  /** Java-friendly version of [[run()]] */
    +  def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
    +
    +  private def updateWeightsAndGaussians(
    +      mean: BDV[Double],
    +      sigma: BreezeMatrix[Double],
    +      weight: Double,
    +      sumWeights: Double): (Double, MultivariateGaussian) = {
    +    val mu = (mean /= weight)
    +    BLAS.syr(-weight, Vectors.fromBreeze(mu),
    +      Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix])
    +    val newWeight = weight / sumWeights
    +    val newGaussian = new MultivariateGaussian(mu, sigma / weight)
    +    (newWeight, newGaussian)
    +  }
    +
       /** Average of dense breeze vectors */
       private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
         val v = BDV.zeros[Double](x(0).length)
         x.foreach(xi => v += xi)
    -    v / x.length.toDouble 
    +    v / x.length.toDouble
       }
    -  
    +
       /**
        * Construct matrix where diagonal entries are element-wise
        * variance of input vectors (computes biased variance)
    @@ -210,14 +242,14 @@ class GaussianMixture private (
     // companion class to provide zero constructor for ExpectationSum
     private object ExpectationSum {
       def zero(k: Int, d: Int): ExpectationSum = {
    -    new ExpectationSum(0.0, Array.fill(k)(0.0), 
    -      Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
    +    new ExpectationSum(0.0, Array.fill(k)(0.0),
    +      Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d)))
       }
    -  
    +
       // compute cluster contributions for each input point
       // (U, T) => U for aggregation
       def add(
    -      weights: Array[Double], 
    +      weights: Array[Double],
           dists: Array[MultivariateGaussian])
           (sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
         val p = weights.zip(dists).map {
    @@ -235,7 +267,7 @@ private object ExpectationSum {
           i = i + 1
         }
         sums
    -  }  
    +  }
     }
     
     // Aggregation class for partial expectation results
    @@ -244,9 +276,9 @@ private class ExpectationSum(
         val weights: Array[Double],
         val means: Array[BDV[Double]],
         val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
    -  
    +
       val k = weights.length
    -  
    +
       def +=(x: ExpectationSum): ExpectationSum = {
         var i = 0
         while (i < k) {
    @@ -257,5 +289,5 @@ private class ExpectationSum(
         }
         logLikelihood += x.logLikelihood
         this
    -  }  
    +  }
     }
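
The new `distributeGaussians` flag above switches the M-step to a parallel `sc.parallelize(...).map(...)` only when `((k - 1) / k) * d > 25`, i.e. roughly when the feature dimension exceeds ~25 and k is not tiny. A small sketch of how that threshold behaves (object and method names are illustrative):

```scala
object DistributeHeuristicSketch {
  // Same predicate as the distributeGaussians heuristic above.
  def shouldDistribute(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25

  def main(args: Array[String]): Unit = {
    println(shouldDistribute(k = 2, d = 30))  // false: 0.5 * 30 = 15
    println(shouldDistribute(k = 10, d = 30)) // true:  0.9 * 30 = 27
    println(shouldDistribute(k = 10, d = 20)) // false: 0.9 * 20 = 18
  }
}
```
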
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
    index ec65a3da689de..cb807c8038101 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
    @@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
     
     import org.apache.spark.SparkContext
     import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
     import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
     import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
    @@ -34,21 +35,20 @@ import org.apache.spark.sql.{SQLContext, Row}
     /**
      * :: Experimental ::
      *
    - * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points 
    - * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are 
    - * the respective mean and covariance for each Gaussian distribution i=1..k. 
    - * 
    - * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is
    - *               the weight for Gaussian i, and weight.sum == 1
    - * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i
    - * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the
    - *              covariance matrix for Gaussian i
    + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
    + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
    + * the respective mean and covariance for each Gaussian distribution i=1..k.
    + *
    + * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is
    + *                the weight for Gaussian i, and weights.sum == 1
    + * @param gaussians Array of MultivariateGaussian where gaussians(i) represents
    + *                  the Multivariate Gaussian (Normal) Distribution for Gaussian i
      */
     @Experimental
     class GaussianMixtureModel(
    -  val weights: Array[Double], 
    -  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
    -  
    +  val weights: Array[Double],
    +  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
    +
       require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
     
       override protected def formatVersion = "1.0"
    @@ -65,20 +65,24 @@ class GaussianMixtureModel(
         val responsibilityMatrix = predictSoft(points)
         responsibilityMatrix.map(r => r.indexOf(r.max))
       }
    -  
    +
    +  /** Java-friendly version of [[predict()]] */
    +  def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
    +    predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
    +
       /**
        * Given the input vectors, return the membership value of each vector
    -   * to all mixture components. 
    +   * to all mixture components.
        */
       def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
         val sc = points.sparkContext
         val bcDists = sc.broadcast(gaussians)
         val bcWeights = sc.broadcast(weights)
    -    points.map { x => 
    +    points.map { x =>
           computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
         }
       }
    -  
    +
       /**
        * Compute the partial assignments for each vector
        */
    @@ -90,7 +94,7 @@ class GaussianMixtureModel(
         val p = weights.zip(dists).map {
           case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt)
         }
    -    val pSum = p.sum 
    +    val pSum = p.sum
         for (i <- 0 until k) {
           p(i) /= pSum
         }
    @@ -127,13 +131,13 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
           val dataArray = Array.tabulate(weights.length) { i =>
             Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
           }
    -      sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
    +      sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
         }
     
         def load(sc: SparkContext, path: String): GaussianMixtureModel = {
           val dataPath = Loader.dataPath(path)
           val sqlContext = new SQLContext(sc)
    -      val dataFrame = sqlContext.parquetFile(dataPath)
    +      val dataFrame = sqlContext.read.parquet(dataPath)
           val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
     
           // Check schema explicitly since erasure makes it hard to use match-case for checking.
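
The `saveAsParquetFile`/`parquetFile` calls replaced throughout these model loaders move to the DataFrame reader/writer interface. A hedged before/after sketch, assuming a spark-shell session where `sqlContext` and a DataFrame named `df` already exist; the path is illustrative:

```scala
val path = "/tmp/gmm-data"

// Older style that this patch removes:
//   df.saveAsParquetFile(path)
//   val loaded = sqlContext.parquetFile(path)

// Reader/writer style the patch switches to:
df.write.parquet(path)
val loaded = sqlContext.read.parquet(path)
```
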
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
    index ba228b11fcec3..8ecb3df11d95e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
    @@ -110,7 +110,7 @@ object KMeansModel extends Loader[KMeansModel] {
           val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
             Cluster(id, point)
           }.toDF()
    -      dataRDD.saveAsParquetFile(Loader.dataPath(path))
    +      dataRDD.write.parquet(Loader.dataPath(path))
         }
     
         def load(sc: SparkContext, path: String): KMeansModel = {
    @@ -120,7 +120,7 @@ object KMeansModel extends Loader[KMeansModel] {
           assert(className == thisClassName)
           assert(formatVersion == thisFormatVersion)
           val k = (metadata \ "k").extract[Int]
    -      val centriods = sqlContext.parquetFile(Loader.dataPath(path))
    +      val centriods = sqlContext.read.parquet(Loader.dataPath(path))
           Loader.checkSchema[Cluster](centriods.schema)
           val localCentriods = centriods.map(Cluster.apply).collect()
           assert(k == localCentriods.size)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    index 6cf26445f20a0..974b26924dfb8 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
     import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
     
     import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaPairRDD
     import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
     import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
     import org.apache.spark.rdd.RDD
    @@ -345,6 +346,11 @@ class DistributedLDAModel private (
         }
       }
     
    +  /** Java-friendly version of [[topicDistributions]] */
    +  def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
    +    JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
    +  }
    +
       // TODO:
       // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    index 6fa2fe053c6a4..8e5154b902d1d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    @@ -273,7 +273,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
        * Default: 1024, following the original Online LDA paper.
        */
       def setTau0(tau0: Double): this.type = {
    -    require(tau0 > 0,  s"LDA tau0 must be positive, but was set to $tau0")
    +    require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0")
         this.tau0 = tau0
         this
       }
    @@ -339,7 +339,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     
       override private[clustering] def initialize(
           docs: RDD[(Long, Vector)],
    -      lda: LDA):  OnlineLDAOptimizer = {
    +      lda: LDA): OnlineLDAOptimizer = {
         this.k = lda.getK
         this.corpusSize = docs.count()
         this.vocabSize = docs.first()._2.size
    @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
        * uses digamma which is accurate but expensive.
        */
       private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
    -    val rowSum =  sum(alpha(breeze.linalg.*, ::))
    +    val rowSum = sum(alpha(breeze.linalg.*, ::))
         val digAlpha = digamma(alpha)
         val digRowSum = digamma(rowSum)
         val result = digAlpha(::, breeze.linalg.*) - digRowSum
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    index aa53e88d59856..e7a243f854e33 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    @@ -74,7 +74,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
           sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
     
           val dataRDD = model.assignments.toDF()
    -      dataRDD.saveAsParquetFile(Loader.dataPath(path))
    +      dataRDD.write.parquet(Loader.dataPath(path))
         }
     
         def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
    @@ -86,7 +86,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
           assert(formatVersion == thisFormatVersion)
     
           val k = (metadata \ "k").extract[Int]
    -      val assignments = sqlContext.parquetFile(Loader.dataPath(path))
    +      val assignments = sqlContext.read.parquet(Loader.dataPath(path))
           Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)
     
           val assignmentsRDD = assignments.map {
    @@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] (
       import org.apache.spark.mllib.clustering.PowerIterationClustering._
     
       /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100,
    -   *  initMode: "random"}. 
    +   *  initMode: "random"}.
        */
       def this() = this(k = 2, maxIterations = 100, initMode = "random")
     
    @@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging {
     
       /**
        * Generates random vertex properties (v0) to start power iteration.
    -   * 
    +   *
        * @param g a graph representing the normalized affinity matrix (W)
        * @return a graph with edges representing W and vertices representing a random vector
        *         with unit 1-norm
    @@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging {
        * Generates the degree vector as the vertex properties (v0) to start power iteration.
        * It is not exactly the node degrees but just the normalized sum similarities. Call it
        * as degree vector because it is used in the PIC paper.
    -   * 
    +   *
        * @param g a graph representing the normalized affinity matrix (W)
        * @return a graph with edges representing W and vertices representing the degree vector
        */
    @@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging {
         val v0 = g.vertices.mapValues(_ / sum)
         GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
       }
    - 
    +
       /**
        * Runs power iteration.
        * @param g input graph with edges representing the normalized affinity matrix (W) and vertices
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
    index 812014a041719..d9b34cec64894 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
    @@ -21,8 +21,10 @@ import scala.reflect.ClassTag
     
     import org.apache.spark.Logging
     import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaSparkContext._
     import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
     import org.apache.spark.rdd.RDD
    +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream}
     import org.apache.spark.streaming.dstream.DStream
     import org.apache.spark.util.Utils
     import org.apache.spark.util.random.XORShiftRandom
    @@ -178,7 +180,7 @@ class StreamingKMeans(
     
       /** Set the decay factor directly (for forgetful algorithms). */
       def setDecayFactor(a: Double): this.type = {
    -    this.decayFactor = decayFactor
    +    this.decayFactor = a
         this
       }
     
    @@ -234,6 +236,9 @@ class StreamingKMeans(
         }
       }
     
    +  /** Java-friendly version of `trainOn`. */
    +  def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)
    +
       /**
        * Use the clustering model to make predictions on batches of data from a DStream.
        *
    @@ -245,6 +250,11 @@ class StreamingKMeans(
         data.map(model.predict)
       }
     
    +  /** Java-friendly version of `predictOn`. */
    +  def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
    +    JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
    +  }
    +
       /**
        * Use the model to make predictions on the values of a DStream and carry over its keys.
        *
    @@ -257,6 +267,14 @@ class StreamingKMeans(
         data.mapValues(model.predict)
       }
     
    +  /** Java-friendly version of `predictOnValues`. */
    +  def predictOnValues[K](
    +      data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
    +    implicit val tag = fakeClassTag[K]
    +    JavaPairDStream.fromPairDStream(
    +      predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]])
    +  }
    +
       /** Check whether cluster centers have been initialized. */
       private[this] def assertInitialized(): Unit = {
         if (model.clusterCenters == null) {
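The StreamingKMeans hunks add Java-friendly `trainOn`/`predictOn`/`predictOnValues` wrappers and fix `setDecayFactor` so the passed value is actually stored. A sketch of the Scala streaming workflow those wrappers mirror, assuming hypothetical training and test directories and an existing `SparkConf` named `conf`:

```
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.{Seconds, StreamingContext}

val ssc = new StreamingContext(conf, Seconds(1))
val trainingData = ssc.textFileStream("/data/train").map(Vectors.parse)
val testData = ssc.textFileStream("/data/test").map(LabeledPoint.parse)

val model = new StreamingKMeans()
  .setK(2)
  .setDecayFactor(0.5)                      // now actually stored thanks to the fix above
  .setRandomCenters(dim = 3, weight = 0.0)

model.trainOn(trainingData)                 // Java callers can now pass a JavaDStream[Vector]
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()
```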
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
    index a8378a76d20ae..bf6eb1d5bd2ab 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
    @@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation
     
     import org.apache.spark.rdd.RDD
     import org.apache.spark.SparkContext._
    +import org.apache.spark.sql.DataFrame
     
     /**
      * Evaluator for multilabel classification.
    @@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._
      */
     class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
     
    +  /**
    +   * An auxiliary constructor taking a DataFrame.
    +   * @param predictionAndLabels a DataFrame with two double array columns: prediction and label
    +   */
    +  private[mllib] def this(predictionAndLabels: DataFrame) =
    +    this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray)))
    +
       private lazy val numDocs: Long = predictionAndLabels.count()
     
       private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
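The new auxiliary constructor is a `private[mllib]` bridge for DataFrame-based callers (notably the Python API); the public entry point is unchanged. A short sketch of that public, RDD-based usage, assuming an existing `SparkContext` named `sc` (e.g. from spark-shell):

```
import org.apache.spark.mllib.evaluation.MultilabelMetrics

// each element is (predicted labels, true labels)
val predictionAndLabels = sc.parallelize(Seq(
  (Array(0.0, 1.0), Array(0.0, 2.0)),
  (Array(0.0, 2.0), Array(0.0, 1.0)),
  (Array.empty[Double], Array(0.0))))

val metrics = new MultilabelMetrics(predictionAndLabels)
println(s"Precision = ${metrics.precision}")
println(s"Recall = ${metrics.recall}")
println(s"F1 measure = ${metrics.f1Measure}")
println(s"Hamming loss = ${metrics.hammingLoss}")
```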
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
    index b9b54b93c27fa..5b5a2a1450f7f 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
    @@ -31,6 +31,8 @@ import org.apache.spark.rdd.RDD
      * ::Experimental::
      * Evaluator for ranking algorithms.
      *
    + * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance.
    + *
      * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
      */
     @Experimental
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
    index c6057c7f837b1..5f8c1dea237b4 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
    @@ -38,7 +38,8 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf
     
       protected def isSorted(array: Array[Int]): Boolean = {
         var i = 1
    -    while (i < array.length) {
    +    val len = array.length
    +    while (i < len) {
           if (array(i) < array(i-1)) return false
           i += 1
         }
    @@ -107,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf
      *                       (ordered by statistic value descending)
      */
     @Experimental
    -class ChiSqSelector (val numTopFeatures: Int) {
    +class ChiSqSelector (val numTopFeatures: Int) extends Serializable {
     
       /**
        * Returns a ChiSquared feature selector.
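Bounding the `while` loop on a local `len` is a micro-optimization, and making `ChiSqSelector` `Serializable` lets the selector be captured safely in closures. A usage sketch with small, discretized features, assuming an existing `SparkContext` named `sc`:

```
import org.apache.spark.mllib.feature.ChiSqSelector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(8.0, 7.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 9.0, 6.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 9.0, 8.0))))

val selector = new ChiSqSelector(numTopFeatures = 1)
val transformer = selector.fit(data)
val filtered = data.map(lp => LabeledPoint(lp.label, transformer.transform(lp.features)))
filtered.collect().foreach(println)
```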
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
    index b0985baf9b278..d67fe6c3ee4f8 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
    @@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._
      * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
      * provided "weight" vector. In other words, it scales each column of the dataset by a scalar
      * multiplier.
    - * @param scalingVector The values used to scale the reference vector's individual components.
    + * @param scalingVec The values used to scale the reference vector's individual components.
      */
     @Experimental
    -class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
    +class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer {
     
       /**
        * Does the hadamard product transformation.
    @@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
        * @return transformed vector.
        */
       override def transform(vector: Vector): Vector = {
    -    require(vector.size == scalingVector.size,
    -      s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}")
    +    require(vector.size == scalingVec.size,
    +      s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}")
         vector match {
           case dv: DenseVector =>
             val values: Array[Double] = dv.values.clone()
    -        val dim = scalingVector.size
    +        val dim = scalingVec.size
             var i = 0
             while (i < dim) {
    -          values(i) *= scalingVector(i)
    +          values(i) *= scalingVec(i)
               i += 1
             }
             Vectors.dense(values)
    @@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
             val dim = values.length
             var i = 0
             while (i < dim) {
    -          values(i) *= scalingVector(indices(i))
    +          values(i) *= scalingVec(indices(i))
               i += 1
             }
             Vectors.sparse(size, indices, values)
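The `scalingVector` parameter is renamed to `scalingVec` (matching the Python API); behavior is unchanged, as the sketch below illustrates. The expected outputs in the comments are my own arithmetic, not output captured from the patch:

```
import org.apache.spark.mllib.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors

val transformer = new ElementwiseProduct(Vectors.dense(0.0, 1.0, 2.0))

// dense: each component is multiplied by the matching weight -> [0.0,2.0,6.0]
println(transformer.transform(Vectors.dense(1.0, 2.0, 3.0)))
// sparse: only stored entries are scaled -> (3,[1],[4.0])
println(transformer.transform(Vectors.sparse(3, Array(1), Array(4.0))))
```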
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
    index a89eea0e21be2..3fab7ea79befc 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
    @@ -144,7 +144,7 @@ private object IDF {
              * Since arrays are initialized to 0 by default,
              * we just omit changing those entries.
              */
    -        if(df(j) >= minDocFreq) {
    +        if (df(j) >= minDocFreq) {
               inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
             }
             j += 1
    @@ -159,7 +159,7 @@ private object IDF {
      * Represents an IDF model that can transform term frequency vectors.
      */
     @Experimental
    -class IDFModel private[mllib] (val idf: Vector) extends Serializable {
    +class IDFModel private[spark] (val idf: Vector) extends Serializable {
     
       /**
        * Transforms term frequency (TF) vectors to TF-IDF vectors.
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
    index 4e01e402b4283..2a66263d8b7d6 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
    @@ -68,7 +68,7 @@ class PCA(val k: Int) {
      * @param k number of principal components.
      * @param pc a principal components Matrix. Each column is one principal component.
      */
    -class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
    +class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer {
       /**
        * Transform a vector by computed Principal Components.
        *
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
    index 6ae6917eae595..c73b8f258060d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
    @@ -90,7 +90,7 @@ class StandardScalerModel (
     
       @DeveloperApi
       def setWithMean(withMean: Boolean): this.type = {
    -    require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null")
    +    require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null")
         this.withMean = withMean
         this
       }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
    index 98e83112f52ae..f087d06d2a46a 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
    @@ -42,32 +42,32 @@ import org.apache.spark.util.random.XORShiftRandom
     import org.apache.spark.sql.{SQLContext, Row}
     
     /**
    - *  Entry in vocabulary 
    + *  Entry in vocabulary
      */
     private case class VocabWord(
       var word: String,
       var cn: Int,
       var point: Array[Int],
       var code: Array[Int],
    -  var codeLen:Int
    +  var codeLen: Int
     )
     
     /**
      * :: Experimental ::
      * Word2Vec creates vector representation of words in a text corpus.
      * The algorithm first constructs a vocabulary from the corpus
    - * and then learns vector representation of words in the vocabulary. 
    - * The vector representation can be used as features in 
    + * and then learns vector representation of words in the vocabulary.
    + * The vector representation can be used as features in
      * natural language processing and machine learning algorithms.
    - * 
    - * We used skip-gram model in our implementation and hierarchical softmax 
    + *
    + * We used skip-gram model in our implementation and hierarchical softmax
      * method to train the model. The variable names in the implementation
      * match the original C implementation.
      *
    - * For original C implementation, see https://code.google.com/p/word2vec/ 
    - * For research papers, see 
    + * For original C implementation, see https://code.google.com/p/word2vec/
    + * For research papers, see
      * Efficient Estimation of Word Representations in Vector Space
    - * and 
    + * and
      * Distributed Representations of Words and Phrases and their Compositionality.
      */
     @Experimental
    @@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging {
       private var numIterations = 1
       private var seed = Utils.random.nextLong()
       private var minCount = 5
    -  
    +
       /**
        * Sets vector size (default: 100).
        */
    @@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging {
         this
       }
     
    -  /** 
    -   * Sets minCount, the minimum number of times a token must appear to be included in the word2vec 
    +  /**
    +   * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
        * model's vocabulary (default: 5).
        */
       def setMinCount(minCount: Int): this.type = {
         this.minCount = minCount
         this
       }
    -  
    +
       private val EXP_TABLE_SIZE = 1000
       private val MAX_EXP = 6
       private val MAX_CODE_LENGTH = 40
    @@ -150,14 +150,17 @@ class Word2Vec extends Serializable with Logging {
           .map(x => VocabWord(
             x._1,
             x._2,
    -        new Array[Int](MAX_CODE_LENGTH), 
    -        new Array[Int](MAX_CODE_LENGTH), 
    +        new Array[Int](MAX_CODE_LENGTH),
    +        new Array[Int](MAX_CODE_LENGTH),
             0))
           .filter(_.cn >= minCount)
           .collect()
           .sortWith((a, b) => a.cn > b.cn)
    -    
    +
         vocabSize = vocab.length
    +    require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
    +      "the setting of minCount, which may be large enough to filter out all words in the input.")
    +
         var a = 0
         while (a < vocabSize) {
           vocabHash += vocab(a).word -> a
    @@ -195,8 +198,8 @@ class Word2Vec extends Serializable with Logging {
         }
         var pos1 = vocabSize - 1
         var pos2 = vocabSize
    -    
    -    var min1i = 0 
    +
    +    var min1i = 0
         var min2i = 0
     
         a = 0
    @@ -265,15 +268,15 @@ class Word2Vec extends Serializable with Logging {
         val words = dataset.flatMap(x => x)
     
         learnVocab(words)
    -    
    +
         createBinaryTree()
    -    
    +
         val sc = dataset.context
     
         val expTable = sc.broadcast(createExpTable())
         val bcVocab = sc.broadcast(vocab)
         val bcVocabHash = sc.broadcast(vocabHash)
    -    
    +
         val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
           new Iterator[Array[Int]] {
             def hasNext: Boolean = iter.hasNext
    @@ -294,7 +297,7 @@ class Word2Vec extends Serializable with Logging {
             }
           }
         }
    -    
    +
         val newSentences = sentences.repartition(numPartitions).cache()
         val initRandom = new XORShiftRandom(seed)
     
    @@ -399,7 +402,7 @@ class Word2Vec extends Serializable with Logging {
           }
         }
         newSentences.unpersist()
    -    
    +
         val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
         var i = 0
         while (i < vocabSize) {
    @@ -428,7 +431,7 @@ class Word2Vec extends Serializable with Logging {
      * Word2Vec model
      */
     @Experimental
    -class Word2VecModel private[mllib] (
    +class Word2VecModel private[spark] (
         model: Map[String, Array[Float]]) extends Serializable with Saveable {
     
       // wordList: Ordered list of words obtained from model.
    @@ -466,7 +469,7 @@ class Word2VecModel private[mllib] (
         val norm1 = blas.snrm2(n, v1, 1)
         val norm2 = blas.snrm2(n, v2, 1)
         if (norm1 == 0 || norm2 == 0) return 0.0
    -    blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
    +    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
       }
     
       override protected def formatVersion = "1.0"
    @@ -477,7 +480,7 @@ class Word2VecModel private[mllib] (
     
       /**
        * Transforms a word to its vector representation
    -   * @param word a word 
    +   * @param word a word
        * @return vector representation of word
        */
       def transform(word: String): Vector = {
    @@ -492,18 +495,18 @@ class Word2VecModel private[mllib] (
       /**
        * Find synonyms of a word
        * @param word a word
    -   * @param num number of synonyms to find  
    +   * @param num number of synonyms to find
        * @return array of (word, cosineSimilarity)
        */
       def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
         val vector = transform(word)
    -    findSynonyms(vector,num)
    +    findSynonyms(vector, num)
       }
     
       /**
        * Find synonyms of the vector representation of a word
        * @param vector vector representation of a word
    -   * @param num number of synonyms to find  
    +   * @param num number of synonyms to find
        * @return array of (word, cosineSimilarity)
        */
       def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    @@ -556,7 +559,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
         def load(sc: SparkContext, path: String): Word2VecModel = {
           val dataPath = Loader.dataPath(path)
           val sqlContext = new SQLContext(sc)
    -      val dataFrame = sqlContext.parquetFile(dataPath)
    +      val dataFrame = sqlContext.read.parquet(dataPath)
     
           val dataArray = dataFrame.select("word", "vector").collect()
     
    @@ -580,7 +583,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
           sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
     
           val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
    -      sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
    +      sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path))
         }
       }
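Beyond the whitespace cleanup, the Word2Vec hunks add a `require` so that an over-aggressive `minCount` fails fast instead of training on an empty vocabulary, and widen `Word2VecModel`'s constructor to `private[spark]` for reuse elsewhere in Spark. A fitting sketch, assuming an existing `SparkContext` named `sc` and a hypothetical whitespace-tokenized corpus path:

```
import org.apache.spark.mllib.feature.Word2Vec

val corpus = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)

val word2vec = new Word2Vec()
  .setVectorSize(50)
  .setMinCount(5)   // if this drops every token, fit() now fails with a clear message

val model = word2vec.fit(corpus)
model.findSynonyms("spark", 5).foreach { case (word, cosine) =>
  println(s"$word $cosine")
}
```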
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
    new file mode 100644
    index 0000000000000..72d0ea0c12e1e
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
    @@ -0,0 +1,119 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm
    +
    +import scala.reflect.ClassTag
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaRDD
    +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
    +import org.apache.spark.mllib.fpm.AssociationRules.Rule
    +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
    +import org.apache.spark.rdd.RDD
    +
    +/**
    + * :: Experimental ::
    + *
    + * Generates association rules from an [[RDD[FreqItemset[Item]]]]. This method only generates
    + * association rules which have a single item as the consequent.
    + *
    + * @since 1.5.0
    + */
    +@Experimental
    +class AssociationRules private[fpm] (
    +    private var minConfidence: Double) extends Logging with Serializable {
    +
    +  /**
    +   * Constructs a default instance with default parameters {minConfidence = 0.8}.
    +   *
    +   * @since 1.5.0
    +   */
    +  def this() = this(0.8)
    +
    +  /**
    +   * Sets the minimal confidence (default: `0.8`).
    +   *
    +   * @since 1.5.0
    +   */
    +  def setMinConfidence(minConfidence: Double): this.type = {
    +    require(minConfidence >= 0.0 && minConfidence <= 1.0)
    +    this.minConfidence = minConfidence
    +    this
    +  }
    +
    +  /**
    +   * Computes the association rules with confidence above [[minConfidence]].
    +   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
    +   * @return an [[RDD[Rule[Item]]]] containing the association rules.
    +   *
    +   * @since 1.5.0
    +   */
    +  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
    +    // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
    +    val candidates = freqItemsets.flatMap { itemset =>
    +      val items = itemset.items
    +      items.flatMap { item =>
    +        items.partition(_ == item) match {
    +        case (consequent, antecedent) if antecedent.nonEmpty =>
    +            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
    +          case _ => None
    +        }
    +      }
    +    }
    +
    +    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
    +    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
    +      .map { case (antecedent, ((consequent, freqUnion), freqAntecedent)) =>
    +      new Rule(antecedent.toArray, consequent.toArray, freqUnion, freqAntecedent)
    +    }.filter(_.confidence >= minConfidence)
    +  }
    +
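    +  /** Java-friendly version of [[run]]. */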
    +  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
    +    val tag = fakeClassTag[Item]
    +    run(freqItemsets.rdd)(tag)
    +  }
    +}
    +
    +object AssociationRules {
    +
    +  /**
    +   * :: Experimental ::
    +   *
    +   * An association rule between sets of items.
    +   * @param antecedent hypotheses of the rule
    +   * @param consequent conclusion of the rule
    +   * @tparam Item item type
    +   *
    +   * @since 1.5.0
    +   */
    +  @Experimental
    +  class Rule[Item] private[fpm] (
    +      val antecedent: Array[Item],
    +      val consequent: Array[Item],
    +      freqUnion: Double,
    +      freqAntecedent: Double) extends Serializable {
    +
    +    def confidence: Double = freqUnion.toDouble / freqAntecedent
    +
    +    require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
    +      val sharedItems = antecedent.toSet.intersect(consequent.toSet)
    +      s"A valid association rule must have disjoint antecedent and " +
    +        s"consequent but ${sharedItems} is present in both."
    +    })
    +  }
    +}
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
    index efa8459d3cdba..e2370a52f4930 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
    @@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
     import org.apache.spark.annotation.Experimental
     import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
    -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
    +import org.apache.spark.mllib.fpm.FPGrowth._
     import org.apache.spark.rdd.RDD
     import org.apache.spark.storage.StorageLevel
     
    @@ -38,9 +38,21 @@ import org.apache.spark.storage.StorageLevel
      * Model trained by [[FPGrowth]], which holds frequent itemsets.
      * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]]
      * @tparam Item item type
    + *
    + * @since 1.3.0
      */
     @Experimental
    -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
    +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
    +  /**
    +   * Generates association rules for the [[Item]]s in [[freqItemsets]].
    +   * @param confidence minimal confidence of the rules produced
    +   * @since 1.5.0
    +   */
    +  def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
    +    val associationRules = new AssociationRules(confidence)
    +    associationRules.run(freqItemsets)
    +  }
    +}
     
     /**
      * :: Experimental ::
    @@ -58,6 +70,8 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex
      *
      * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning
      *       (Wikipedia)]]
    + *
    + * @since 1.3.0
      */
     @Experimental
     class FPGrowth private (
    @@ -67,11 +81,15 @@ class FPGrowth private (
       /**
        * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same
        * as the input data}.
    +   *
    +   * @since 1.3.0
        */
       def this() = this(0.3, -1)
     
       /**
        * Sets the minimal support level (default: `0.3`).
    +   *
    +   * @since 1.3.0
        */
       def setMinSupport(minSupport: Double): this.type = {
         this.minSupport = minSupport
    @@ -80,6 +98,8 @@ class FPGrowth private (
     
       /**
        * Sets the number of partitions used by parallel FP-growth (default: same as input data).
    +   *
    +   * @since 1.3.0
        */
       def setNumPartitions(numPartitions: Int): this.type = {
         this.numPartitions = numPartitions
    @@ -90,6 +110,8 @@ class FPGrowth private (
        * Computes an FP-Growth model that contains frequent itemsets.
        * @param data input data set, each element contains a transaction
        * @return an [[FPGrowthModel]]
    +   *
    +   * @since 1.3.0
        */
       def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
         if (data.getStorageLevel == StorageLevel.NONE) {
    @@ -190,6 +212,8 @@ class FPGrowth private (
     
     /**
      * :: Experimental ::
    + *
    + * @since 1.3.0
      */
     @Experimental
     object FPGrowth {
    @@ -199,11 +223,15 @@ object FPGrowth {
        * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
        * @param freq frequency
        * @tparam Item item type
    +   *
    +   * @since 1.3.0
        */
       class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {
     
         /**
          * Returns items in a Java List.
    +     *
    +     * @since 1.3.0
          */
         def javaItems: java.util.List[Item] = {
           items.toList.asJava
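With `FPGrowthModel.generateAssociationRules` in place, rule mining becomes a one-liner on a fitted model. A sketch assuming an existing `SparkContext` named `sc`:

```
import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("a", "c"),
  Array("b", "c")))

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run(transactions)

model.freqItemsets.collect().foreach { itemset =>
  println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
}

// New in this patch: derive rules with a single-item consequent directly from the model.
model.generateAssociationRules(confidence = 0.8).collect().foreach { rule =>
  println(rule.antecedent.mkString("[", ",", "]")
    + " => " + rule.consequent.mkString("[", ",", "]")
    + ", " + rule.confidence)
}
```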
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
    new file mode 100644
    index 0000000000000..39c48b084e550
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
    @@ -0,0 +1,113 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.fpm
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.Experimental
    +
    +/**
    + *
    + * :: Experimental ::
    + *
    + * Calculates all patterns of a projected database locally.
    + */
    +@Experimental
    +private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    +
    +  /**
    +   * Calculate all patterns of a projected database.
    +   * @param minCount minimum count
    +   * @param maxPatternLength maximum pattern length
    +   * @param prefix prefix
    +   * @param projectedDatabase the projected database
    +   * @return a set of sequential pattern pairs,
    +   *         the key of pair is sequential pattern (a list of items),
    +   *         the value of pair is the pattern's count.
    +   */
    +  def run(
    +      minCount: Long,
    +      maxPatternLength: Int,
    +      prefix: Array[Int],
    +      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
    +    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
    +    val frequentPatternAndCounts = frequentPrefixAndCounts
    +      .map(x => (prefix ++ Array(x._1), x._2))
    +    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
    +      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
    +
    +    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
    +    if (continueProcess) {
    +      val nextPatterns = prefixProjectedDatabases
    +        .map(x => run(minCount, maxPatternLength, x._1, x._2))
    +        .reduce(_ ++ _)
    +      frequentPatternAndCounts ++ nextPatterns
    +    } else {
    +      frequentPatternAndCounts
    +    }
    +  }
    +
    +  /**
    +   * Calculates the suffix sequence following a prefix in a sequence.
    +   * @param prefix prefix
    +   * @param sequence sequence
    +   * @return suffix sequence
    +   */
    +  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    +    val index = sequence.indexOf(prefix)
    +    if (index == -1) {
    +      Array()
    +    } else {
    +      sequence.drop(index + 1)
    +    }
    +  }
    +
    +  /**
    +   * Generates frequent items by filtering the input data using minimal count level.
    +   * @param minCount the absolute minimum count
    +   * @param sequences sequences data
    +   * @return array of item and count pair
    +   */
    +  private def getFreqItemAndCounts(
    +      minCount: Long,
    +      sequences: Array[Array[Int]]): Array[(Int, Long)] = {
    +    sequences.flatMap(_.distinct)
    +      .groupBy(x => x)
    +      .mapValues(_.length.toLong)
    +      .filter(_._2 >= minCount)
    +      .toArray
    +  }
    +
    +  /**
    +   * Get the frequent prefixes' projected database.
    +   * @param prePrefix the frequent prefixes' prefix
    +   * @param frequentPrefixes frequent prefixes
    +   * @param sequences sequences data
    +   * @return prefixes and projected database
    +   */
    +  private def getPatternAndProjectedDatabase(
    +      prePrefix: Array[Int],
    +      frequentPrefixes: Array[Int],
    +      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
    +    val filteredProjectedDatabase = sequences
    +      .map(x => x.filter(frequentPrefixes.contains(_)))
    +    frequentPrefixes.map { x =>
    +      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
    +      (prePrefix ++ Array(x), sub)
    +    }.filter(x => x._2.nonEmpty)
    +  }
    +}
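`LocalPrefixSpan` is `private[fpm]`, so it cannot be called from user code; the snippet below simply restates the suffix-projection step to make its semantics concrete (it mirrors `getSuffix`, it is not the object itself):

```
// Everything after the first occurrence of the prefix item; empty if the item never occurs.
def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
  val index = sequence.indexOf(prefix)
  if (index == -1) Array() else sequence.drop(index + 1)
}

println(getSuffix(2, Array(1, 2, 3, 2, 4)).mkString(","))  // 3,2,4
println(getSuffix(5, Array(1, 2, 3)).mkString(","))        // (empty)
```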
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
    new file mode 100644
    index 0000000000000..9d8c60ef0fc45
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
    @@ -0,0 +1,157 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.fpm
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +
    +/**
    + *
    + * :: Experimental ::
    + *
    + * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
    + * The PrefixSpan algorithm is described in
    + * [[http://doi.org/10.1109/ICDE.2001.914830]].
    + *
    + * @param minSupport the minimal support level of the sequential pattern; any pattern that
    + *                   appears more than (minSupport * size-of-the-dataset) times will be output
    + * @param maxPatternLength the maximal length of the sequential pattern; any pattern that
    + *                   exceeds this length will not be output
    + *
    + * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
    + *       (Wikipedia)]]
    + */
    +@Experimental
    +class PrefixSpan private (
    +    private var minSupport: Double,
    +    private var maxPatternLength: Int) extends Logging with Serializable {
    +
    +  /**
    +   * Constructs a default instance with default parameters
    +   * {minSupport: `0.1`, maxPatternLength: `10`}.
    +   */
    +  def this() = this(0.1, 10)
    +
    +  /**
    +   * Sets the minimal support level (default: `0.1`).
    +   */
    +  def setMinSupport(minSupport: Double): this.type = {
    +    require(minSupport >= 0 && minSupport <= 1,
    +      "The minimum support value must be between 0 and 1, including 0 and 1.")
    +    this.minSupport = minSupport
    +    this
    +  }
    +
    +  /**
    +   * Sets maximal pattern length (default: `10`).
    +   */
    +  def setMaxPatternLength(maxPatternLength: Int): this.type = {
    +    require(maxPatternLength >= 1,
    +      "The maximum pattern length value must be greater than 0.")
    +    this.maxPatternLength = maxPatternLength
    +    this
    +  }
    +
    +  /**
    +   * Find the complete set of sequential patterns in the input sequences.
    +   * @param sequences input data set, where each element is a sequence,
    +   *                  i.e., an ordered list of items.
    +   * @return a set of sequential pattern pairs, where the key of each pair
    +   *         is a pattern (a list of items) and the value is the pattern's count.
    +   */
    +  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
    +    if (sequences.getStorageLevel == StorageLevel.NONE) {
    +      logWarning("Input data is not cached.")
    +    }
    +    val minCount = getMinCount(sequences)
    +    val lengthOnePatternsAndCounts =
    +      getFreqItemAndCounts(minCount, sequences).collect()
    +    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
    +      lengthOnePatternsAndCounts.map(_._1), sequences)
    +    val groupedProjectedDatabase = prefixAndProjectedDatabase
    +      .map(x => (x._1.toSeq, x._2))
    +      .groupByKey()
    +      .map(x => (x._1.toArray, x._2.toArray))
    +    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
    +    val lengthOnePatternsAndCountsRdd =
    +      sequences.sparkContext.parallelize(
    +        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
    +    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
    +    allPatterns
    +  }
    +
    +  /**
    +   * Gets the minimum count (sequences count * minSupport).
    +   * @param sequences input data set containing a set of sequences
    +   * @return the minimum count
    +   */
    +  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
    +    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
    +  }
    +
    +  /**
    +   * Generates frequent items by filtering the input data using minimal count level.
    +   * @param minCount the absolute minimum count
    +   * @param sequences original sequences data
    +   * @return array of item and count pair
    +   */
    +  private def getFreqItemAndCounts(
    +      minCount: Long,
    +      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
    +    sequences.flatMap(_.distinct.map((_, 1L)))
    +      .reduceByKey(_ + _)
    +      .filter(_._2 >= minCount)
    +  }
    +
    +  /**
    +   * Get the frequent prefixes' projected database.
    +   * @param frequentPrefixes frequent prefixes
    +   * @param sequences sequences data
    +   * @return prefixes and projected database
    +   */
    +  private def getPrefixAndProjectedDatabase(
    +      frequentPrefixes: Array[Int],
    +      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
    +    val filteredSequences = sequences.map { p =>
    +      p.filter (frequentPrefixes.contains(_) )
    +    }
    +    filteredSequences.flatMap { x =>
    +      frequentPrefixes.map { y =>
    +        val sub = LocalPrefixSpan.getSuffix(y, x)
    +        (Array(y), sub)
    +      }.filter(_._2.nonEmpty)
    +    }
    +  }
    +
    +  /**
    +   * Calculates the patterns locally.
    +   * @param minCount the absolute minimum count
    +   * @param data patterns and projected sequences data
    +   * @return patterns
    +   */
    +  private def getPatternsInLocal(
    +      minCount: Long,
    +      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
    +    data.flatMap { x =>
    +      LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
    +    }
    +  }
    +}
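End-to-end, the new `PrefixSpan` is driven much like `FPGrowth`: configure support and maximum pattern length, then `run` on an RDD of integer-encoded sequences. A sketch assuming an existing `SparkContext` named `sc`:

```
import org.apache.spark.mllib.fpm.PrefixSpan

val sequences = sc.parallelize(Seq(
  Array(1, 3, 2, 5),
  Array(1, 2, 3, 4),
  Array(1, 2, 5),
  Array(6)), 2).cache()   // cache to avoid the "Input data is not cached" warning

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)

prefixSpan.run(sequences).collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ",", "]") + ", " + count)
}
```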
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    index 87052e1ba8539..3523f1804325d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    @@ -213,9 +213,9 @@ private[spark] object BLAS extends Serializable with Logging {
       def scal(a: Double, x: Vector): Unit = {
         x match {
           case sx: SparseVector =>
    -        f2jBLAS.dscal(sx.values.size, a, sx.values, 1)
    +        f2jBLAS.dscal(sx.values.length, a, sx.values, 1)
           case dx: DenseVector =>
    -        f2jBLAS.dscal(dx.values.size, a, dx.values, 1)
    +        f2jBLAS.dscal(dx.values.length, a, dx.values, 1)
           case _ =>
             throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
         }
    @@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging {
         }
         _nativeBLAS
       }
    - 
    +
       /**
        * A := alpha * x * x^T^ + A
        * @param alpha a real scalar that will be multiplied to x * x^T^.
    @@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging {
             j += 1
           }
           i += 1
    -    }    
    +    }
       }
     
       private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
    @@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging {
       def gemv(
           alpha: Double,
           A: Matrix,
    -      x: DenseVector,
    +      x: Vector,
           beta: Double,
           y: DenseVector): Unit = {
         require(A.numCols == x.size,
    @@ -473,27 +473,32 @@ private[spark] object BLAS extends Serializable with Logging {
         if (alpha == 0.0) {
           logDebug("gemv: alpha is equal to 0. Returning y.")
         } else {
    -      A match {
    -        case sparse: SparseMatrix =>
    -          gemv(alpha, sparse, x, beta, y)
    -        case dense: DenseMatrix =>
    -          gemv(alpha, dense, x, beta, y)
    +      (A, x) match {
    +        case (smA: SparseMatrix, dvx: DenseVector) =>
    +          gemv(alpha, smA, dvx, beta, y)
    +        case (smA: SparseMatrix, svx: SparseVector) =>
    +          gemv(alpha, smA, svx, beta, y)
    +        case (dmA: DenseMatrix, dvx: DenseVector) =>
    +          gemv(alpha, dmA, dvx, beta, y)
    +        case (dmA: DenseMatrix, svx: SparseVector) =>
    +          gemv(alpha, dmA, svx, beta, y)
             case _ =>
    -          throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
    +          throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
    +            s"${A.getClass} and vector type ${x.getClass}.")
           }
         }
       }
     
       /**
        * y := alpha * A * x + beta * y
    -   * For `DenseMatrix` A.
    +   * For `DenseMatrix` A and `DenseVector` x.
        */
       private def gemv(
           alpha: Double,
           A: DenseMatrix,
           x: DenseVector,
           beta: Double,
    -      y: DenseVector): Unit =  {
    +      y: DenseVector): Unit = {
         val tStrA = if (A.isTransposed) "T" else "N"
         val mA = if (!A.isTransposed) A.numRows else A.numCols
         val nA = if (!A.isTransposed) A.numCols else A.numRows
    @@ -503,14 +508,134 @@ private[spark] object BLAS extends Serializable with Logging {
     
       /**
        * y := alpha * A * x + beta * y
    -   * For `SparseMatrix` A.
    +   * For `DenseMatrix` A and `SparseVector` x.
    +   */
    +  private def gemv(
    +      alpha: Double,
    +      A: DenseMatrix,
    +      x: SparseVector,
    +      beta: Double,
    +      y: DenseVector): Unit = {
    +    val mA: Int = A.numRows
    +    val nA: Int = A.numCols
    +
    +    val Avals = A.values
    +
    +    val xIndices = x.indices
    +    val xNnz = xIndices.length
    +    val xValues = x.values
    +    val yValues = y.values
    +
    +    if (alpha == 0.0) {
    +      scal(beta, y)
    +      return
    +    }
    +
    +    if (A.isTransposed) {
    +      var rowCounterForA = 0
    +      while (rowCounterForA < mA) {
    +        var sum = 0.0
    +        var k = 0
    +        while (k < xNnz) {
    +          sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA)
    +          k += 1
    +        }
    +        yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA)
    +        rowCounterForA += 1
    +      }
    +    } else {
    +      var rowCounterForA = 0
    +      while (rowCounterForA < mA) {
    +        var sum = 0.0
    +        var k = 0
    +        while (k < xNnz) {
    +          sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA)
    +          k += 1
    +        }
    +        yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA)
    +        rowCounterForA += 1
    +      }
    +    }
    +  }
    +
    +  /**
    +   * y := alpha * A * x + beta * y
    +   * For `SparseMatrix` A and `SparseVector` x.
    +   */
    +  private def gemv(
    +      alpha: Double,
    +      A: SparseMatrix,
    +      x: SparseVector,
    +      beta: Double,
    +      y: DenseVector): Unit = {
    +    val xValues = x.values
    +    val xIndices = x.indices
    +    val xNnz = xIndices.length
    +
    +    val yValues = y.values
    +
    +    val mA: Int = A.numRows
    +    val nA: Int = A.numCols
    +
    +    val Avals = A.values
    +    val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
    +    val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
    +
    +    if (alpha == 0.0) {
    +      scal(beta, y)
    +      return
    +    }
    +
    +    if (A.isTransposed) {
    +      var rowCounter = 0
    +      while (rowCounter < mA) {
    +        var i = Arows(rowCounter)
    +        val indEnd = Arows(rowCounter + 1)
    +        var sum = 0.0
    +        var k = 0
    +        while (k < xNnz && i < indEnd) {
    +          if (xIndices(k) == Acols(i)) {
    +            sum += Avals(i) * xValues(k)
    +            i += 1
    +          }
    +          k += 1
    +        }
    +        yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
    +        rowCounter += 1
    +      }
    +    } else {
    +      scal(beta, y)
    +
    +      var colCounterForA = 0
    +      var k = 0
    +      while (colCounterForA < nA && k < xNnz) {
    +        if (xIndices(k) == colCounterForA) {
    +          var i = Acols(colCounterForA)
    +          val indEnd = Acols(colCounterForA + 1)
    +
    +          val xTemp = xValues(k) * alpha
    +          while (i < indEnd) {
    +            val rowIndex = Arows(i)
    +            yValues(rowIndex) += Avals(i) * xTemp
    +            i += 1
    +          }
    +          k += 1
    +        }
    +        colCounterForA += 1
    +      }
    +    }
    +  }
    +
    +  /**
    +   * y := alpha * A * x + beta * y
    +   * For `SparseMatrix` A and `DenseVector` x.
        */
       private def gemv(
           alpha: Double,
           A: SparseMatrix,
           x: DenseVector,
           beta: Double,
    -      y: DenseVector): Unit =  {
    +      y: DenseVector): Unit = {
         val xValues = x.values
         val yValues = y.values
         val mA: Int = A.numRows
    @@ -534,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging {
             rowCounter += 1
           }
         } else {
    -      // Scale vector first if `beta` is not equal to 0.0
    -      if (beta != 0.0) {
    -        scal(beta, y)
    -      }
    +      scal(beta, y)
           // Perform matrix-vector multiplication and add to y
           var colCounterForA = 0
           while (colCounterForA < nA) {
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
    index 866936aa4f118..ae3ba3099c878 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
    @@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition {
     
         require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE,
           s"k = $k and/or n = $n are too large to compute an eigendecomposition")
    -    
    +
         var ido = new intW(0)
         var info = new intW(0)
         var resid = new Array[Double](n)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    index 3fa5e068d16d4..0df07663405a3 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    @@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
     import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
     
     import org.apache.spark.annotation.DeveloperApi
    -import org.apache.spark.sql.Row
    -import org.apache.spark.sql.types._
     import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.types._
     
     /**
      * Trait for a local matrix.
    @@ -77,8 +77,13 @@ sealed trait Matrix extends Serializable {
         C
       }
     
    -  /** Convenience method for `Matrix`-`DenseVector` multiplication. */
    +  /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */
       def multiply(y: DenseVector): DenseVector = {
    +    multiply(y.asInstanceOf[Vector])
    +  }
    +
    +  /** Convenience method for `Matrix`-`Vector` multiplication. */
    +  def multiply(y: Vector): DenseVector = {
         val output = new DenseVector(new Array[Double](numRows))
         BLAS.gemv(1.0, this, y, 0.0, output)
         output
    @@ -109,6 +114,16 @@ sealed trait Matrix extends Serializable {
        *          corresponding value in the matrix with type `Double`.
        */
       private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
    +
    +  /**
    +   * Find the number of non-zero active values.
    +   */
    +  def numNonzeros: Int
    +
    +  /**
    +   * Find the number of values stored explicitly. These values can be zero as well.
    +   */
    +  def numActives: Int
     }
     
     @DeveloperApi
    @@ -132,7 +147,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
           ))
       }
     
    -  override def serialize(obj: Any): Row = {
    +  override def serialize(obj: Any): InternalRow = {
         val row = new GenericMutableRow(7)
         obj match {
           case sm: SparseMatrix =>
    @@ -158,9 +173,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
     
       override def deserialize(datum: Any): Matrix = {
         datum match {
    -      // TODO: something wrong with UDT serialization, should never happen.
    -      case m: Matrix => m
    -      case row: Row =>
    +      case row: InternalRow =>
             require(row.length == 7,
               s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
             val tpe = row.getByte(0)
    @@ -188,10 +201,13 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         }
       }
     
    -  override def hashCode(): Int = 1994
    +  // see [SPARK-8647], this achieves the needed constant hash code without constant no.
    +  override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode()
     
       override def typeName: String = "matrix"
     
    +  override def pyUDT: String = "pyspark.mllib.linalg.MatrixUDT"
    +
       private[spark] override def asNullable: MatrixUDT = this
     }
     
    @@ -273,7 +289,8 @@ class DenseMatrix(
     
       override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
     
    -  private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f))
    +  private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
    +    isTransposed)
     
       private[mllib] def update(f: Double => Double): DenseMatrix = {
         val len = values.length
    @@ -315,6 +332,10 @@ class DenseMatrix(
         }
       }
     
    +  override def numNonzeros: Int = values.count(_ != 0)
    +
    +  override def numActives: Int = values.length
    +
       /**
        * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed
        * set to false.
    @@ -535,7 +556,7 @@ class SparseMatrix(
       }
     
       private[mllib] def map(f: Double => Double) =
    -    new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f))
    +    new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)
     
       private[mllib] def update(f: Double => Double): SparseMatrix = {
         val len = values.length
    @@ -584,6 +605,11 @@ class SparseMatrix(
       def toDense: DenseMatrix = {
         new DenseMatrix(numRows, numCols, toArray)
       }
    +
    +  override def numNonzeros: Int = values.count(_ != 0)
    +
    +  override def numActives: Int = values.length
    +
     }
     
     /**
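Together with the BLAS changes earlier in this patch, `Matrix.multiply` now accepts any `Vector` (the `DenseVector` overload is kept for binary compatibility), and `numNonzeros`/`numActives` report storage statistics. The expected values in the comments are my own arithmetic for this toy matrix:

```
import org.apache.spark.mllib.linalg.{Matrices, Vectors}

// 2 x 3 matrix in column-major order: rows are [1,2,3] and [4,5,6]
val A = Matrices.dense(2, 3, Array(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))

val dense = Vectors.dense(1.0, 0.0, 2.0)
val sparse = Vectors.sparse(3, Array(0, 2), Array(1.0, 2.0))

println(A.multiply(dense))   // [7.0,16.0]
println(A.multiply(sparse))  // [7.0,16.0] -- sparse input now dispatches to a dedicated gemv
println(A.numActives)        // 6: all stored entries
println(A.numNonzeros)       // 6: explicitly stored values that are non-zero
```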
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    index e4ba9a243737d..68c933752a959 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    @@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
     import org.apache.spark.SparkException
     import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.mllib.util.NumericParser
    -import org.apache.spark.sql.Row
    +import org.apache.spark.sql.catalyst.InternalRow
     import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
     import org.apache.spark.sql.types._
     
    @@ -181,29 +181,28 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
           StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
       }
     
    -  override def serialize(obj: Any): Row = {
    -    val row = new GenericMutableRow(4)
    +  override def serialize(obj: Any): InternalRow = {
         obj match {
           case SparseVector(size, indices, values) =>
    +        val row = new GenericMutableRow(4)
             row.setByte(0, 0)
             row.setInt(1, size)
             row.update(2, indices.toSeq)
             row.update(3, values.toSeq)
    +        row
           case DenseVector(values) =>
    +        val row = new GenericMutableRow(4)
             row.setByte(0, 1)
             row.setNullAt(1)
             row.setNullAt(2)
             row.update(3, values.toSeq)
    +        row
         }
    -    row
       }
     
       override def deserialize(datum: Any): Vector = {
         datum match {
    -      // TODO: something wrong with UDT serialization
    -      case v: Vector =>
    -        v
    -      case row: Row =>
    +      case row: InternalRow =>
             require(row.length == 4,
               s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
             val tpe = row.getByte(0)
    @@ -231,7 +230,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
         }
       }
     
    -  override def hashCode: Int = 7919
    +  // See [SPARK-8647]: this gives a constant hash code without hard-coding a magic number.
    +  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode()
     
       override def typeName: String = "vector"
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
    index 3be530fa07537..1c33b43ea7a8a 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
    @@ -146,7 +146,7 @@ class IndexedRowMatrix(
           val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
             IndexedRow(i, v)
           }
    -      new IndexedRowMatrix(indexedRows, nRows, nCols)
    +      new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
         } else {
           null
         }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    index 9a89a6f3a515f..1626da9c3d2ee 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    @@ -219,7 +219,7 @@ class RowMatrix(
     
         val computeMode = mode match {
           case "auto" =>
    -        if(k > 5000) {
    +        if (k > 5000) {
               logWarning(s"computing svd with k=$k and n=$n, please check necessity")
             }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
    index 4b7d0589c973b..ab7611fd077ef 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
    @@ -19,13 +19,14 @@ package org.apache.spark.mllib.optimization
     
     import scala.collection.mutable.ArrayBuffer
     
    -import breeze.linalg.{DenseVector => BDV}
    +import breeze.linalg.{DenseVector => BDV, norm}
     
     import org.apache.spark.annotation.{Experimental, DeveloperApi}
     import org.apache.spark.Logging
     import org.apache.spark.rdd.RDD
     import org.apache.spark.mllib.linalg.{Vectors, Vector}
     
    +
     /**
      * Class used to solve an optimization problem using Gradient Descent.
      * @param gradient Gradient function to be used.
    @@ -38,6 +39,7 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
       private var numIterations: Int = 100
       private var regParam: Double = 0.0
       private var miniBatchFraction: Double = 1.0
    +  private var convergenceTol: Double = 0.001
     
       /**
        * Set the initial step size of SGD for the first step. Default 1.0.
    @@ -75,6 +77,23 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
         this
       }
     
    +  /**
     +   * Set the convergence tolerance. Default 0.001.
     +   * convergenceTol is the condition that decides when to terminate the iterations.
     +   * Termination is decided as follows:
     +   * - If the norm of the new solution vector is >1, the difference between solution vectors
     +   *   is compared to a relative tolerance, i.e. normalized by the norm of
     +   *   the new solution vector.
     +   * - If the norm of the new solution vector is <=1, the difference between solution vectors
     +   *   is compared to an absolute tolerance, i.e. not normalized.
     +   * Must be between 0.0 and 1.0 inclusive.
    +   */
    +  def setConvergenceTol(tolerance: Double): this.type = {
    +    require(0.0 <= tolerance && tolerance <= 1.0)
    +    this.convergenceTol = tolerance
    +    this
    +  }
    +
       /**
        * Set the gradient function (of the loss function of one single data example)
        * to be used for SGD.
    @@ -112,7 +131,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
           numIterations,
           regParam,
           miniBatchFraction,
    -      initialWeights)
    +      initialWeights,
    +      convergenceTol)
         weights
       }
     
    @@ -131,17 +151,20 @@ object GradientDescent extends Logging {
        * Sampling, and averaging the subgradients over this subset is performed using one standard
        * spark map-reduce in each iteration.
        *
    -   * @param data - Input data for SGD. RDD of the set of data examples, each of
    -   *               the form (label, [feature values]).
    -   * @param gradient - Gradient object (used to compute the gradient of the loss function of
    -   *                   one single data example)
    -   * @param updater - Updater function to actually perform a gradient step in a given direction.
    -   * @param stepSize - initial step size for the first step
    -   * @param numIterations - number of iterations that SGD should be run.
    -   * @param regParam - regularization parameter
    -   * @param miniBatchFraction - fraction of the input data set that should be used for
    -   *                            one iteration of SGD. Default value 1.0.
    -   *
    +   * @param data Input data for SGD. RDD of the set of data examples, each of
    +   *             the form (label, [feature values]).
    +   * @param gradient Gradient object (used to compute the gradient of the loss function of
    +   *                 one single data example)
    +   * @param updater Updater function to actually perform a gradient step in a given direction.
    +   * @param stepSize initial step size for the first step
    +   * @param numIterations number of iterations that SGD should be run.
    +   * @param regParam regularization parameter
    +   * @param miniBatchFraction fraction of the input data set that should be used for
    +   *                          one iteration of SGD. Default value 1.0.
     +   * @param convergenceTol Minibatch iteration will end before numIterations if the relative
     +   *                       difference between the current weight and the previous weight is less
     +   *                       than this value. Convergence is measured by the L2 norm of the weight
     +   *                       difference. Default value 0.001. Must be between 0.0 and 1.0 inclusive.
        * @return A tuple containing two elements. The first element is a column matrix containing
        *         weights for every feature, and the second element is an array containing the
        *         stochastic loss computed for every iteration.
    @@ -154,9 +177,20 @@ object GradientDescent extends Logging {
           numIterations: Int,
           regParam: Double,
           miniBatchFraction: Double,
    -      initialWeights: Vector): (Vector, Array[Double]) = {
    +      initialWeights: Vector,
    +      convergenceTol: Double): (Vector, Array[Double]) = {
    +
     +    // convergenceTol is only reliable with non-minibatch settings, i.e. miniBatchFraction = 1.0
    +    if (miniBatchFraction < 1.0 && convergenceTol > 0.0) {
    +      logWarning("Testing against a convergenceTol when using miniBatchFraction " +
    +        "< 1.0 can be unstable because of the stochasticity in sampling.")
    +    }
     
         val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
     +    // Record the previous and current weights to compute the solution vector difference
    +
    +    var previousWeights: Option[Vector] = None
    +    var currentWeights: Option[Vector] = None
     
         val numExamples = data.count()
     
    @@ -179,9 +213,11 @@ object GradientDescent extends Logging {
          * if it's L2 updater; for L1 updater, the same logic is followed.
          */
         var regVal = updater.compute(
    -      weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
    +      weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
     
    -    for (i <- 1 to numIterations) {
    +    var converged = false // indicates whether converged based on convergenceTol
    +    var i = 1
    +    while (!converged && i <= numIterations) {
           val bcWeights = data.context.broadcast(weights)
           // Sample a subset (fraction miniBatchFraction) of the total data
           // compute and sum up the subgradients on this subset (this is one map-reduce)
    @@ -204,12 +240,21 @@ object GradientDescent extends Logging {
              */
             stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
             val update = updater.compute(
    -          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
    +          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
    +          stepSize, i, regParam)
             weights = update._1
             regVal = update._2
    +
    +        previousWeights = currentWeights
    +        currentWeights = Some(weights)
    +        if (previousWeights != None && currentWeights != None) {
    +          converged = isConverged(previousWeights.get,
    +            currentWeights.get, convergenceTol)
    +        }
           } else {
             logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
           }
    +      i += 1
         }
     
         logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
    @@ -218,4 +263,32 @@ object GradientDescent extends Logging {
         (weights, stochasticLossHistory.toArray)
     
       }
    +
    +  def runMiniBatchSGD(
    +      data: RDD[(Double, Vector)],
    +      gradient: Gradient,
    +      updater: Updater,
    +      stepSize: Double,
    +      numIterations: Int,
    +      regParam: Double,
    +      miniBatchFraction: Double,
    +      initialWeights: Vector): (Vector, Array[Double]) =
    +    GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations,
    +                                    regParam, miniBatchFraction, initialWeights, 0.001)
    +
    +
    +  private def isConverged(
    +      previousWeights: Vector,
    +      currentWeights: Vector,
    +      convergenceTol: Double): Boolean = {
    +    // To compare with convergence tolerance.
    +    val previousBDV = previousWeights.toBreeze.toDenseVector
    +    val currentBDV = currentWeights.toBreeze.toDenseVector
    +
    +    // This represents the difference of updated weights in the iteration.
    +    val solutionVecDiff: Double = norm(previousBDV - currentBDV)
    +
    +    solutionVecDiff < convergenceTol * Math.max(norm(currentBDV), 1.0)
    +  }
    +
     }
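The convergence check added to runMiniBatchSGD compares the L2 norm of the weight difference against a relative tolerance when the new solution vector has norm greater than 1, and an absolute tolerance otherwise. A minimal standalone sketch of that criterion using Breeze (names are illustrative, not taken from the patch):

```
import breeze.linalg.{norm, DenseVector => BDV}

// True when the step from `previous` to `current` is small enough:
// relative tolerance when norm(current) > 1, absolute tolerance otherwise.
def isConvergedSketch(previous: BDV[Double], current: BDV[Double], tol: Double): Boolean = {
  val solutionVecDiff = norm(previous - current)
  solutionVecDiff < tol * math.max(norm(current), 1.0)
}
```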
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
    index 3ed3a5b9b3843..9f463e0cafb6f 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
    @@ -116,7 +116,8 @@ class L1Updater extends Updater {
         // Apply proximal operator (soft thresholding)
         val shrinkageVal = regParam * thisIterStepSize
         var i = 0
    -    while (i < brzWeights.length) {
    +    val len = brzWeights.length
    +    while (i < len) {
           val wi = brzWeights(i)
           brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal)
           i += 1
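The L1Updater loop above applies soft thresholding (the proximal operator for the L1 penalty) to each weight. A minimal sketch of that operator on a single coordinate, where `shrinkage` plays the role of `regParam * thisIterStepSize` from the updater:

```
// Shrinks the weight towards zero by `shrinkage`, clamping at zero.
def softThreshold(w: Double, shrinkage: Double): Double =
  math.signum(w) * math.max(0.0, math.abs(w) - shrinkage)
```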
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
    index 354e90f3eeaa6..5e882d4ebb10b 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
    @@ -23,13 +23,16 @@ import javax.xml.transform.stream.StreamResult
     import org.jpmml.model.JAXBUtil
     
     import org.apache.spark.SparkContext
    +import org.apache.spark.annotation.{DeveloperApi, Experimental}
     import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
     
     /**
    + * :: DeveloperApi ::
      * Export model to the PMML format
      * Predictive Model Markup Language (PMML) is an XML-based file format
      * developed by the Data Mining Group (www.dmg.org).
      */
    +@DeveloperApi
     trait PMMLExportable {
     
       /**
    @@ -41,30 +44,38 @@ trait PMMLExportable {
       }
     
       /**
    +   * :: Experimental ::
        * Export the model to a local file in PMML format
        */
    +  @Experimental
       def toPMML(localPath: String): Unit = {
         toPMML(new StreamResult(new File(localPath)))
       }
     
       /**
    +   * :: Experimental ::
        * Export the model to a directory on a distributed file system in PMML format
        */
    +  @Experimental
       def toPMML(sc: SparkContext, path: String): Unit = {
         val pmml = toPMML()
         sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
       }
     
       /**
    +   * :: Experimental ::
        * Export the model to the OutputStream in PMML format
        */
    +  @Experimental
       def toPMML(outputStream: OutputStream): Unit = {
         toPMML(new StreamResult(outputStream))
       }
     
       /**
    +   * :: Experimental ::
        * Export the model to a String in PMML format
        */
    +  @Experimental
       def toPMML(): String = {
         val writer = new StringWriter
         toPMML(new StreamResult(writer))
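For context, a hedged usage sketch of the annotated export methods, assuming a model that mixes in PMMLExportable (KMeansModel is used here for illustration) and an existing SparkContext `sc`:

```
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(Vectors.dense(0.0, 0.0), Vectors.dense(9.0, 9.0)))
val model = KMeans.train(data, k = 2, maxIterations = 10)

val pmmlString: String = model.toPMML()       // in-memory PMML document
model.toPMML("/tmp/kmeans.pmml")              // local file
model.toPMML(sc, "hdfs:///models/kmeans")     // directory on a distributed file system
```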
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
    index 34b447584e521..622b53a252ac5 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
    @@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
      * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
      */
     private[mllib] class BinaryClassificationPMMLModelExport(
    -    model : GeneralizedLinearModel, 
    +    model : GeneralizedLinearModel,
         description : String,
         normalizationMethod : RegressionNormalizationMethodType,
    -    threshold: Double) 
    +    threshold: Double)
       extends PMMLModelExport {
     
       populateBinaryClassificationPMML()
    @@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
                .withUsageType(FieldUsageType.ACTIVE))
              regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
            }
    -       
    +
            // add target field
            val targetField = FieldName.create("target")
            dataDictionary
    @@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport(
            miningSchema
              .withMiningFields(new MiningField(targetField)
              .withUsageType(FieldUsageType.TARGET))
    -       
    +
            dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
    -       
    +
            pmml.setDataDictionary(dataDictionary)
            pmml.withModels(regressionModel)
          }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
    index ebdeae50bb32f..c5fdecd3ca17f 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
    @@ -25,7 +25,7 @@ import scala.beans.BeanProperty
     import org.dmg.pmml.{Application, Header, PMML, Timestamp}
     
     private[mllib] trait PMMLModelExport {
    -  
    +
       /**
        * Holder of the exported model in PMML format
        */
    @@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport {
       val pmml: PMML = new PMML
     
       setHeader(pmml)
    -  
    +
       private def setHeader(pmml: PMML): Unit = {
         val version = getClass.getPackage.getImplementationVersion
         val app = new Application().withName("Apache Spark MLlib").withVersion(version)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
    index c16e83d6a067d..29bd689e1185a 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
    @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel
     import org.apache.spark.mllib.regression.RidgeRegressionModel
     
     private[mllib] object PMMLModelExportFactory {
    -  
    +
       /**
    -   * Factory object to help creating the necessary PMMLModelExport implementation 
     +   * Factory object to help create the necessary PMMLModelExport implementation
        * taking as input the machine learning model (for example KMeansModel).
        */
       def createPMMLModelExport(model: Any): PMMLModelExport = {
    @@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory {
             new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
           case svm: SVMModel =>
             new BinaryClassificationPMMLModelExport(
    -          svm, "linear SVM", RegressionNormalizationMethodType.NONE, 
    +          svm, "linear SVM", RegressionNormalizationMethodType.NONE,
               svm.getThreshold.getOrElse(0.0))
           case logistic: LogisticRegressionModel =>
             if (logistic.numClasses == 2) {
    @@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory {
               "PMML Export not supported for model: " + model.getClass.getName)
         }
       }
    -  
    +
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
    index 8341bb86afd71..174d5e0f6c9f0 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
    @@ -52,7 +52,7 @@ object RandomRDDs {
           numPartitions: Int = 0,
           seed: Long = Utils.random.nextLong()): RDD[Double] = {
         val uniform = new UniformGenerator()
    -    randomRDD(sc, uniform,  size, numPartitionsOrDefault(sc, numPartitions), seed)
    +    randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed)
       }
     
       /**
    @@ -234,7 +234,7 @@ object RandomRDDs {
        *
        * @param sc SparkContext used to create the RDD.
        * @param shape shape parameter (> 0) for the gamma distribution
    -   * @param scale scale parameter (> 0) for the gamma distribution  
    +   * @param scale scale parameter (> 0) for the gamma distribution
        * @param size Size of the RDD.
        * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
        * @param seed Random seed (default: a random long integer).
    @@ -293,7 +293,7 @@ object RandomRDDs {
        *
        * @param sc SparkContext used to create the RDD.
        * @param mean mean for the log normal distribution
    -   * @param std standard deviation for the log normal distribution  
    +   * @param std standard deviation for the log normal distribution
        * @param size Size of the RDD.
        * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
        * @param seed Random seed (default: a random long integer).
    @@ -671,7 +671,7 @@ object RandomRDDs {
        *
        * @param sc SparkContext used to create the RDD.
        * @param shape shape parameter (> 0) for the gamma distribution.
    -   * @param scale scale parameter (> 0) for the gamma distribution. 
    +   * @param scale scale parameter (> 0) for the gamma distribution.
        * @param numRows Number of Vectors in the RDD.
        * @param numCols Number of elements in each Vector.
        * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
    index dddefe1944e9d..93290e6508529 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
    @@ -175,7 +175,7 @@ class ALS private (
       /**
        * :: DeveloperApi ::
        * Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default
    -   * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. 
    +   * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
        * `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement,
        * at the cost of speed.
        */
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
    index 88c2148403313..43d219a49cf4e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
    @@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger}
     
     import scala.collection.mutable
     
    +import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
     import com.github.fommil.netlib.BLAS.{getInstance => blas}
     import org.apache.hadoop.fs.Path
     import org.json4s._
    @@ -79,6 +80,30 @@ class MatrixFactorizationModel(
         blas.ddot(rank, userVector, 1, productVector, 1)
       }
     
    +  /**
     +   * Returns the approximate numbers of distinct users and products in the given usersProducts tuples.
    +   * This method is based on `countApproxDistinct` in class `RDD`.
    +   *
    +   * @param usersProducts  RDD of (user, product) pairs.
    +   * @return approximate numbers of users and products.
    +   */
    +  private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = {
    +    val zeroCounterUser = new HyperLogLogPlus(4, 0)
    +    val zeroCounterProduct = new HyperLogLogPlus(4, 0)
    +    val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))(
    +      (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => {
    +        hllTuple._1.offer(v._1)
    +        hllTuple._2.offer(v._2)
    +        hllTuple
    +      },
    +      (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => {
    +        h1._1.addAll(h2._1)
    +        h1._2.addAll(h2._2)
    +        h1
    +      })
    +    (aggregated._1.cardinality(), aggregated._2.cardinality())
    +  }
    +
       /**
        * Predict the rating of many users for many products.
        * The output RDD has an element per each element in the input RDD (including all duplicates)
    @@ -88,12 +113,30 @@ class MatrixFactorizationModel(
        * @return RDD of Ratings.
        */
       def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
    -    val users = userFeatures.join(usersProducts).map {
    -      case (user, (uFeatures, product)) => (product, (user, uFeatures))
    -    }
    -    users.join(productFeatures).map {
    -      case (product, ((user, uFeatures), pFeatures)) =>
    -        Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
     +    // Previously the partitioning of the ratings was based only on the given products.
     +    // So if the usersProducts given for prediction contains only a few products, or
     +    // even a single product, the generated ratings end up in few partitions (or one)
     +    // and cannot exploit high parallelism.
     +    // Here we compute approximate numbers of distinct users and products, and then decide
     +    // whether the partitioning should be based on users or on products.
    +    val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts)
    +
    +    if (usersCount < productsCount) {
    +      val users = userFeatures.join(usersProducts).map {
    +        case (user, (uFeatures, product)) => (product, (user, uFeatures))
    +      }
    +      users.join(productFeatures).map {
    +        case (product, ((user, uFeatures), pFeatures)) =>
    +          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
    +      }
    +    } else {
    +      val products = productFeatures.join(usersProducts.map(_.swap)).map {
    +        case (product, (pFeatures, user)) => (user, (product, pFeatures))
    +      }
    +      products.join(userFeatures).map {
    +        case (user, ((product, pFeatures), uFeatures)) =>
    +          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
    +      }
         }
       }
     
    @@ -281,8 +324,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
           val metadata = compact(render(
             ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
           sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
    -      model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path))
    -      model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path))
    +      model.userFeatures.toDF("id", "features").write.parquet(userPath(path))
    +      model.productFeatures.toDF("id", "features").write.parquet(productPath(path))
         }
     
         def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
    @@ -292,11 +335,11 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
           assert(className == thisClassName)
           assert(formatVersion == thisFormatVersion)
           val rank = (metadata \ "rank").extract[Int]
    -      val userFeatures = sqlContext.parquetFile(userPath(path))
    +      val userFeatures = sqlContext.read.parquet(userPath(path))
             .map { case Row(id: Int, features: Seq[_]) =>
               (id, features.asInstanceOf[Seq[Double]].toArray)
             }
    -      val productFeatures = sqlContext.parquetFile(productPath(path))
    +      val productFeatures = sqlContext.read.parquet(productPath(path))
             .map { case Row(id: Int, features: Seq[_]) =>
             (id, features.asInstanceOf[Seq[Double]].toArray)
           }
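The new predict path picks its join order from approximate distinct counts computed with HyperLogLog++. A minimal local sketch of that counting primitive (the class comes from the stream-lib dependency imported above; precision 4 matches the patch and gives a coarse estimate):

```
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus

val hll = new HyperLogLogPlus(4, 0)
(1 to 100000).foreach(i => hll.offer(i % 500))   // 500 distinct values, offered many times each
println(hll.cardinality())                       // roughly 500, within the error of precision 4
```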
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
    index 26be30ff9d6fd..6709bd79bc820 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
    @@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
          */
         val initialWeights = {
           if (numOfLinearPredictor == 1) {
    -        Vectors.dense(new Array[Double](numFeatures))
    +        Vectors.zeros(numFeatures)
           } else if (addIntercept) {
    -        Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor))
    +        Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
           } else {
    -        Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor))
    +        Vectors.zeros(numFeatures * numOfLinearPredictor)
           }
         }
         run(input, initialWeights)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
    index be2a00c2dfea4..f3b46c75c05f3 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
    @@ -69,7 +69,8 @@ class IsotonicRegressionModel (
       /** Asserts the input array is monotone with the given ordering. */
       private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = {
         var i = 1
    -    while (i < xs.length) {
    +    val len = xs.length
    +    while (i < len) {
           require(ord.compare(xs(i - 1), xs(i)) <= 0,
             s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.")
           i += 1
    @@ -169,26 +170,26 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
         case class Data(boundary: Double, prediction: Double)
     
         def save(
    -        sc: SparkContext, 
    -        path: String, 
    -        boundaries: Array[Double], 
    -        predictions: Array[Double], 
    +        sc: SparkContext,
    +        path: String,
    +        boundaries: Array[Double],
    +        predictions: Array[Double],
             isotonic: Boolean): Unit = {
           val sqlContext = new SQLContext(sc)
     
           val metadata = compact(render(
    -        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ 
    +        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
               ("isotonic" -> isotonic)))
           sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
     
           sqlContext.createDataFrame(
             boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) }
    -      ).saveAsParquetFile(dataPath(path))
    +      ).write.parquet(dataPath(path))
         }
     
         def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
           val sqlContext = new SQLContext(sc)
    -      val dataRDD = sqlContext.parquetFile(dataPath(path))
    +      val dataRDD = sqlContext.read.parquet(dataPath(path))
     
           checkSchema[Data](dataRDD.schema)
           val dataArray = dataRDD.select("boundary", "prediction").collect()
    @@ -202,7 +203,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
       override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
         implicit val formats = DefaultFormats
         val (loadedClassName, version, metadata) = loadMetadata(sc, path)
    -    val isotonic =  (metadata \ "isotonic").extract[Boolean]
    +    val isotonic = (metadata \ "isotonic").extract[Boolean]
         val classNameV1_0 = SaveLoadV1_0.thisClassName
         (loadedClassName, version) match {
           case (className, "1.0") if className == classNameV1_0 =>
    @@ -329,11 +330,12 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
         }
     
         var i = 0
    -    while (i < input.length) {
    +    val len = input.length
    +    while (i < len) {
           var j = i
     
           // Find monotonicity violating sequence, if any.
    -      while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) {
    +      while (j < len - 1 && input(j)._1 > input(j + 1)._1) {
             j = j + 1
           }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
    index e0c03d8180c7a..7d28ffad45c92 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
    @@ -73,7 +73,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
     
     /**
      * Train a regression model with L2-regularization using Stochastic Gradient Descent.
    - * This solves the l1-regularized least squares regression formulation
    + * This solves the l2-regularized least squares regression formulation
      *          f(weights) = 1/2n ||A weights-y||^2^  + regParam/2 ||weights||^2^
      * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
      * its corresponding right hand side label y.
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
    index cea8f3f47307b..141052ba813ee 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
    @@ -83,21 +83,15 @@ abstract class StreamingLinearAlgorithm[
           throw new IllegalArgumentException("Model must be initialized before starting training.")
         }
         data.foreachRDD { (rdd, time) =>
    -      val initialWeights =
    -        model match {
    -          case Some(m) =>
    -            m.weights
    -          case None =>
    -            val numFeatures = rdd.first().features.size
    -            Vectors.dense(numFeatures)
    +      if (!rdd.isEmpty) {
    +        model = Some(algorithm.run(rdd, model.get.weights))
    +        logInfo(s"Model updated at time ${time.toString}")
    +        val display = model.get.weights.size match {
    +          case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
    +          case _ => model.get.weights.toArray.mkString("[", ",", "]")
             }
    -      model = Some(algorithm.run(rdd, initialWeights))
    -      logInfo("Model updated at time %s".format(time.toString))
    -      val display = model.get.weights.size match {
    -        case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
    -        case _ => model.get.weights.toArray.mkString("[", ",", "]")
    +        logInfo(s"Current model: weights, ${display}")
           }
    -      logInfo("Current model: weights, %s".format (display))
         }
       }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
    index a49153bf73c0d..c6d04464a12ba 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
    @@ -79,10 +79,16 @@ class StreamingLinearRegressionWithSGD private[mllib] (
         this
       }
     
    -  /** Set the initial weights. Default: [0.0, 0.0]. */
    +  /** Set the initial weights. */
       def setInitialWeights(initialWeights: Vector): this.type = {
         this.model = Some(algorithm.createModel(initialWeights, 0.0))
         this
       }
     
    +  /** Set the convergence tolerance. */
    +  def setConvergenceTol(tolerance: Double): this.type = {
    +    this.algorithm.optimizer.setConvergenceTol(tolerance)
    +    this
    +  }
    +
     }
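A hedged configuration sketch for the new setter, assuming the usual no-argument constructor and a known feature count:

```
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD

val numFeatures = 3
val model = new StreamingLinearRegressionWithSGD()
  .setInitialWeights(Vectors.zeros(numFeatures))
  .setConvergenceTol(0.001)
```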
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
    index b55944f74f623..317d3a5702636 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
    @@ -60,7 +60,7 @@ private[regression] object GLMRegressionModel {
           val data = Data(weights, intercept)
           val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
           // TODO: repartition with 1 partition after SPARK-5532 gets fixed
    -      dataRDD.saveAsParquetFile(Loader.dataPath(path))
    +      dataRDD.write.parquet(Loader.dataPath(path))
         }
     
         /**
    @@ -72,7 +72,7 @@ private[regression] object GLMRegressionModel {
         def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
           val datapath = Loader.dataPath(path)
           val sqlContext = new SQLContext(sc)
    -      val dataRDD = sqlContext.parquetFile(datapath)
    +      val dataRDD = sqlContext.read.parquet(datapath)
           val dataArray = dataRDD.select("weights", "intercept").take(1)
           assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
           val data = dataArray(0)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
    index 79747cc5d7d74..58a50f9c19f14 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
    @@ -17,52 +17,101 @@
     
     package org.apache.spark.mllib.stat
     
    +import com.github.fommil.netlib.BLAS.{getInstance => blas}
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.rdd.RDD
     
    -private[stat] object KernelDensity {
    +/**
    + * :: Experimental ::
    + * Kernel density estimation. Given a sample from a population, estimate its probability density
     + * function at each of the given evaluation points using kernels. Only the Gaussian kernel is supported.
    + *
    + * Scala example:
    + *
    + * {{{
    + * val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0))
    + * val kd = new KernelDensity()
    + *   .setSample(sample)
    + *   .setBandwidth(3.0)
    + * val densities = kd.estimate(Array(-1.0, 2.0, 5.0))
    + * }}}
    + */
    +@Experimental
    +class KernelDensity extends Serializable {
    +
    +  import KernelDensity._
    +
    +  /** Bandwidth of the kernel function. */
    +  private var bandwidth: Double = 1.0
    +
    +  /** A sample from a population. */
    +  private var sample: RDD[Double] = _
    +
       /**
    -   * Given a set of samples from a distribution, estimates its density at the set of given points.
    -   * Uses a Gaussian kernel with the given standard deviation.
    +   * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`).
        */
    -  def estimate(samples: RDD[Double], standardDeviation: Double,
    -      evaluationPoints: Array[Double]): Array[Double] = {
    -    if (standardDeviation <= 0.0) {
    -      throw new IllegalArgumentException("Standard deviation must be positive")
    -    }
    +  def setBandwidth(bandwidth: Double): this.type = {
    +    require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.")
    +    this.bandwidth = bandwidth
    +    this
    +  }
     
    -    // This gets used in each Gaussian PDF computation, so compute it up front
    -    val logStandardDeviationPlusHalfLog2Pi =
    -      math.log(standardDeviation) + 0.5 * math.log(2 * math.Pi)
    +  /**
    +   * Sets the sample to use for density estimation.
    +   */
    +  def setSample(sample: RDD[Double]): this.type = {
    +    this.sample = sample
    +    this
    +  }
    +
    +  /**
    +   * Sets the sample to use for density estimation (for Java users).
    +   */
    +  def setSample(sample: JavaRDD[java.lang.Double]): this.type = {
    +    this.sample = sample.rdd.asInstanceOf[RDD[Double]]
    +    this
    +  }
    +
    +  /**
    +   * Estimates probability density function at the given array of points.
    +   */
    +  def estimate(points: Array[Double]): Array[Double] = {
    +    val sample = this.sample
    +    val bandwidth = this.bandwidth
    +
    +    require(sample != null, "Must set sample before calling estimate.")
     
    -    val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))(
    +    val n = points.length
    +    // This gets used in each Gaussian PDF computation, so compute it up front
    +    val logStandardDeviationPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi)
    +    val (densities, count) = sample.aggregate((new Array[Double](n), 0L))(
           (x, y) => {
             var i = 0
    -        while (i < evaluationPoints.length) {
    -          x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi,
    -            evaluationPoints(i))
    +        while (i < n) {
    +          x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i))
               i += 1
             }
    -        (x._1, i)
    +        (x._1, x._2 + 1)
           },
           (x, y) => {
    -        var i = 0
    -        while (i < evaluationPoints.length) {
    -          x._1(i) += y._1(i)
    -          i += 1
    -        }
    +        blas.daxpy(n, 1.0, y._1, 1, x._1, 1)
             (x._1, x._2 + y._2)
           })
    -
    -    var i = 0
    -    while (i < points.length) {
    -      points(i) /= count
    -      i += 1
    -    }
    -    points
    +    blas.dscal(n, 1.0 / count, densities, 1)
    +    densities
       }
    +}
    +
    +private object KernelDensity {
     
    -  private def normPdf(mean: Double, standardDeviation: Double,
    -      logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = {
    +  /** Evaluates the PDF of a normal distribution. */
    +  def normPdf(
    +      mean: Double,
    +      standardDeviation: Double,
    +      logStandardDeviationPlusHalfLog2Pi: Double,
    +      x: Double): Double = {
         val x0 = x - mean
         val x1 = x0 / standardDeviation
         val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi
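The distributed estimate() above computes, at each evaluation point, the average of Gaussian densities centered at the sample values. A minimal local (non-Spark) sketch of the same quantity, with illustrative names:

```
def estimateLocally(sample: Seq[Double], bandwidth: Double, points: Array[Double]): Array[Double] = {
  // computed once up front, as in the aggregate above
  val logSdPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi)
  points.map { p =>
    sample.map { y =>
      val x1 = (p - y) / bandwidth
      math.exp(-0.5 * x1 * x1 - logSdPlusHalfLog2Pi)   // Gaussian pdf centered at y
    }.sum / sample.size
  }
}
```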
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
    index fcc2a148791bd..d321cc554c1cc 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
    @@ -70,23 +70,30 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         require(n == sample.size, s"Dimensions mismatch when adding new sample." +
           s" Expecting $n but got ${sample.size}.")
     
    +    val localCurrMean = currMean
    +    val localCurrM2n = currM2n
    +    val localCurrM2 = currM2
    +    val localCurrL1 = currL1
    +    val localNnz = nnz
    +    val localCurrMax = currMax
    +    val localCurrMin = currMin
         sample.foreachActive { (index, value) =>
           if (value != 0.0) {
    -        if (currMax(index) < value) {
    -          currMax(index) = value
    +        if (localCurrMax(index) < value) {
    +          localCurrMax(index) = value
             }
    -        if (currMin(index) > value) {
    -          currMin(index) = value
    +        if (localCurrMin(index) > value) {
    +          localCurrMin(index) = value
             }
     
    -        val prevMean = currMean(index)
    +        val prevMean = localCurrMean(index)
             val diff = value - prevMean
    -        currMean(index) = prevMean + diff / (nnz(index) + 1.0)
    -        currM2n(index) += (value - currMean(index)) * diff
    -        currM2(index) += value * value
    -        currL1(index) += math.abs(value)
    +        localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0)
    +        localCurrM2n(index) += (value - localCurrMean(index)) * diff
    +        localCurrM2(index) += value * value
    +        localCurrL1(index) += math.abs(value)
     
    -        nnz(index) += 1.0
    +        localNnz(index) += 1.0
           }
         }
     
    @@ -130,14 +137,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
           }
         } else if (totalCnt == 0 && other.totalCnt != 0) {
           this.n = other.n
    -      this.currMean = other.currMean.clone
    -      this.currM2n = other.currM2n.clone
    -      this.currM2 = other.currM2.clone
    -      this.currL1 = other.currL1.clone
    +      this.currMean = other.currMean.clone()
    +      this.currM2n = other.currM2n.clone()
    +      this.currM2 = other.currM2.clone()
    +      this.currL1 = other.currL1.clone()
           this.totalCnt = other.totalCnt
    -      this.nnz = other.nnz.clone
    -      this.currMax = other.currMax.clone
    -      this.currMin = other.currMin.clone
    +      this.nnz = other.nnz.clone()
    +      this.currMax = other.currMax.clone()
    +      this.currMin = other.currMin.clone()
         }
         this
       }
    @@ -165,7 +172,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         if (denominator > 0.0) {
           val deltaMean = currMean
           var i = 0
    -      while (i < currM2n.size) {
    +      val len = currM2n.length
    +      while (i < len) {
             realVariance(i) =
               currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
             realVariance(i) /= denominator
    @@ -211,7 +219,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
         val realMagnitude = Array.ofDim[Double](n)
     
         var i = 0
    -    while (i < currM2.size) {
    +    val len = currM2.length
    +    while (i < len) {
           realMagnitude(i) = math.sqrt(currM2(i))
           i += 1
         }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
    index 32561620ac914..90332028cfb3a 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
    @@ -17,12 +17,16 @@
     
     package org.apache.spark.mllib.stat
     
    +import scala.annotation.varargs
    +
     import org.apache.spark.annotation.Experimental
    +import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.mllib.linalg.distributed.RowMatrix
     import org.apache.spark.mllib.linalg.{Matrix, Vector}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.stat.correlation.Correlations
    -import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult}
    +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest,
    +  KolmogorovSmirnovTestResult}
     import org.apache.spark.rdd.RDD
     
     /**
    @@ -80,6 +84,10 @@ object Statistics {
        */
       def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
     
    +  /** Java-friendly version of [[corr()]] */
    +  def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
    +    corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
    +
       /**
        * Compute the correlation for the input RDDs using the specified method.
        * Methods currently supported: `pearson` (default), `spearman`.
    @@ -96,6 +104,10 @@ object Statistics {
        */
       def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
     
    +  /** Java-friendly version of [[corr()]] */
    +  def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
    +    corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
    +
       /**
        * Conduct Pearson's chi-squared goodness of fit test of the observed data against the
        * expected distribution.
    @@ -151,16 +163,37 @@ object Statistics {
       }
     
       /**
    -   * Given an empirical distribution defined by the input RDD of samples, estimate its density at
    -   * each of the given evaluation points using a Gaussian kernel.
    +   * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
    +   * continuous distribution. By comparing the largest difference between the empirical cumulative
     +   * distribution of the sample data and the theoretical distribution, we can provide a test for
     +   * the null hypothesis that the sample data comes from that theoretical distribution.
    +   * For more information on KS Test:
    +   * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
        *
    -   * @param samples The samples RDD used to define the empirical distribution.
    -   * @param standardDeviation The standard deviation of the kernel Gaussians.
    -   * @param evaluationPoints The points at which to estimate densities.
    -   * @return An array the same size as evaluationPoints with the density at each point.
    +   * @param data an `RDD[Double]` containing the sample of data to test
    +   * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
    +   *        statistic, p-value, and null hypothesis.
    +   */
    +  def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double)
    +    : KolmogorovSmirnovTestResult = {
    +    KolmogorovSmirnovTest.testOneSample(data, cdf)
    +  }
    +
    +  /**
    +   * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability
    +   * distribution equality. Currently supports the normal distribution, taking as parameters
    +   * the mean and standard deviation.
    +   * (distName = "norm")
    +   * @param data an `RDD[Double]` containing the sample of data to test
    +   * @param distName a `String` name for a theoretical distribution
    +   * @param params `Double*` specifying the parameters to be used for the theoretical distribution
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test
    +   *        statistic, p-value, and null hypothesis.
        */
    -  def kernelDensity(samples: RDD[Double], standardDeviation: Double,
    -      evaluationPoints: Iterable[Double]): Array[Double] = {
    -    KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray)
    +  @varargs
    +  def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*)
    +    : KolmogorovSmirnovTestResult = {
    +    KolmogorovSmirnovTest.testOneSample(data, distName, params: _*)
       }
     }
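A hedged usage sketch of the new test entry points, assuming an existing SparkContext `sc`; the result exposes the usual test-result fields (statistic, p-value, null hypothesis):

```
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD

val data: RDD[Double] = sc.parallelize(Seq(0.1, 0.15, 0.2, 0.3, 0.25))

// user-supplied CDF (here the CDF of Uniform(0, 1), valid for values in [0, 1])
val result1 = Statistics.kolmogorovSmirnovTest(data, (x: Double) => x)
// convenience form: standard normal with mean 0.0 and standard deviation 1.0
val result2 = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0)
println(s"statistic = ${result2.statistic}, p-value = ${result2.pValue}")
```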
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
    index cd6add9d60b0d..cf51b24ff777f 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
    @@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils
      * the event that the covariance matrix is singular, the density will be computed in a
      * reduced dimensional subspace under which the distribution is supported.
      * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
    - * 
    + *
      * @param mu The mean vector of the distribution
      * @param sigma The covariance matrix of the distribution
      */
     @DeveloperApi
     class MultivariateGaussian (
    -    val mu: Vector, 
    +    val mu: Vector,
         val sigma: Matrix) extends Serializable {
     
       require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
       require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
    -  
    +
       private val breezeMu = mu.toBreeze.toDenseVector
    -  
    +
       /**
        * private[mllib] constructor
    -   * 
    +   *
        * @param mu The mean vector of the distribution
        * @param sigma The covariance matrix of the distribution
        */
       private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = {
         this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma))
       }
    -  
    +
       /**
        * Compute distribution dependent constants:
        *    rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
    -   *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) 
    +   *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
        */
       private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
    -  
    +
       /** Returns density of this multivariate Gaussian at given point, x */
       def pdf(x: Vector): Double = {
         pdf(x.toBreeze)
       }
    -  
    +
       /** Returns the log-density of this multivariate Gaussian at given point, x */
       def logpdf(x: Vector): Double = {
         logpdf(x.toBreeze)
       }
    -  
    +
       /** Returns density of this multivariate Gaussian at given point, x */
       private[mllib] def pdf(x: BV[Double]): Double = {
         math.exp(logpdf(x))
       }
    -  
    +
       /** Returns the log-density of this multivariate Gaussian at given point, x */
       private[mllib] def logpdf(x: BV[Double]): Double = {
         val delta = x - breezeMu
         val v = rootSigmaInv * delta
         u + v.t * v * -0.5
       }
    -  
    +
       /**
        * Calculate distribution dependent components used for the density function:
        *    pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
        * where k is length of the mean vector.
    -   * 
    -   * We here compute distribution-fixed parts 
    +   *
    +   * We here compute distribution-fixed parts
        *  log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
        * and
        *  D^(-1/2)^ * U, where sigma = U * D * U.t
    -   *  
    +   *
        * Both the determinant and the inverse can be computed from the singular value decomposition
        * of sigma.  Noting that covariance matrices are always symmetric and positive semi-definite,
        * we can use the eigendecomposition. We also do not compute the inverse directly; noting
    -   * that 
    -   * 
    +   * that
    +   *
        *    sigma = U * D * U.t
    -   *    inv(Sigma) = U * inv(D) * U.t 
    +   *    inv(Sigma) = U * inv(D) * U.t
        *               = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
    -   * 
    +   *
        * and thus
    -   * 
    +   *
        *    -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U  * (x-mu))^2^
    -   *  
    -   * To guard against singular covariance matrices, this method computes both the 
    +   *
    +   * To guard against singular covariance matrices, this method computes both the
        * pseudo-determinant and the pseudo-inverse (Moore-Penrose).  Singular values are considered
        * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
        * relation to the maximum singular value (same tolerance used by, e.g., Octave).
        */
       private def calculateCovarianceConstants: (DBM[Double], Double) = {
         val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
    -    
    +
         // For numerical stability, values are considered to be non-zero only if they exceed tol.
         // This prevents any inverted value from exceeding (eps * n * max(d))^-1
         val tol = MLUtils.EPSILON * max(d) * d.length
    -    
    +
         try {
           // log(pseudo-determinant) is sum of the logs of all non-zero singular values
           val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
    -      
    -      // calculate the root-pseudo-inverse of the diagonal matrix of singular values 
    +
    +      // calculate the root-pseudo-inverse of the diagonal matrix of singular values
           // by inverting the square root of all non-zero values
           val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
    -    
    +
           (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
         } catch {
           case uex: UnsupportedOperationException =>
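The scaladoc above derives the density from the eigendecomposition of sigma: rootSigmaInv = D^(-1/2)^ * U and u = log((2*pi)^(-k/2)^ * pseudo-det(sigma)^(-1/2)^). A hedged Breeze sketch of the resulting log-density, given those precomputed constants (names are illustrative):

```
import breeze.linalg.{DenseMatrix, DenseVector}

// logpdf(x) = u - 0.5 * || rootSigmaInv * (x - mu) ||^2
def logpdfSketch(
    x: DenseVector[Double],
    mu: DenseVector[Double],
    rootSigmaInv: DenseMatrix[Double],
    u: Double): Double = {
  val v = rootSigmaInv * (x - mu)
  u - 0.5 * (v dot v)
}
```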
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
    index ea82d39b72c03..23c8d7c7c8075 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
    @@ -196,7 +196,7 @@ private[stat] object ChiSqTest extends Logging {
        * Pearson's independence test on the input contingency matrix.
        * TODO: optimize for SparseMatrix when it becomes supported.
        */
    -  def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = {
    +  def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = {
         val method = methodFromString(methodName)
         val numRows = counts.numRows
         val numCols = counts.numCols
    @@ -205,8 +205,10 @@ private[stat] object ChiSqTest extends Logging {
         val colSums = new Array[Double](numCols)
         val rowSums = new Array[Double](numRows)
         val colMajorArr = counts.toArray
    +    val colMajorArrLen = colMajorArr.length
    +
         var i = 0
    -    while (i < colMajorArr.size) {
    +    while (i < colMajorArrLen) {
           val elem = colMajorArr(i)
           if (elem < 0.0) {
             throw new IllegalArgumentException("Contingency table cannot contain negative entries.")
    @@ -220,7 +222,7 @@ private[stat] object ChiSqTest extends Logging {
         // second pass to collect statistic
         var statistic = 0.0
         var j = 0
    -    while (j < colMajorArr.size) {
    +    while (j < colMajorArrLen) {
           val col = j / numRows
           val colSum = colSums(col)
           if (colSum == 0.0) {
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
    new file mode 100644
    index 0000000000000..d89b0059d83f3
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
    @@ -0,0 +1,194 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.stat.test
    +
    +import scala.annotation.varargs
    +
    +import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution}
    +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.rdd.RDD
    +
    +/**
     + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
     + * continuous distribution. By comparing the largest difference between the empirical cumulative
     + * distribution of the sample data and the theoretical distribution, we can provide a test for
     + * the null hypothesis that the sample data comes from that theoretical distribution.
    + * For more information on KS Test:
    + * @see [[https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test]]
    + *
    + * Implementation note: We seek to implement the KS test with a minimal number of distributed
    + * passes. We sort the RDD, and then perform the following operations on a per-partition basis:
    + * calculate an empirical cumulative distribution value for each observation, and a theoretical
    + * cumulative distribution value. We know the latter to be correct, while the former will be off by
    + * a constant (how large the constant is depends on how many values precede it in other partitions).
    + * However, given that this constant simply shifts the empirical CDF upwards, but doesn't
    + * change its shape, and furthermore, that constant is the same within a given partition, we can
    + * pick 2 values in each partition that can potentially resolve to the largest global distance.
    + * Namely, we pick the minimum distance and the maximum distance. Additionally, we keep track of how
    + * many elements are in each partition. Once these three values have been returned for every
    + * partition, we can collect and operate locally. Locally, we can now adjust each distance by the
     + * appropriate constant (the cumulative sum of the number of elements in the prior partitions
     + * divided by the data set size). Finally, we take the maximum absolute value, and this is the statistic.
    + */
    +private[stat] object KolmogorovSmirnovTest extends Logging {
    +
    +  // Null hypothesis for the type of KS test to be included in the result.
    +  object NullHypothesis extends Enumeration {
    +    type NullHypothesis = Value
    +    val OneSampleTwoSided = Value("Sample follows theoretical distribution")
    +  }
    +
    +  /**
    +   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
    +   * @param data `RDD[Double]` data on which to run test
    +   * @param cdf `Double => Double` function to calculate the theoretical CDF
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
    +   *        results (p-value, statistic, and null hypothesis)
    +   */
    +  def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = {
    +    val n = data.count().toDouble
    +    val localData = data.sortBy(x => x).mapPartitions { part =>
    +      val partDiffs = oneSampleDifferences(part, n, cdf) // local distances
    +      searchOneSampleCandidates(partDiffs) // candidates: local extrema
    +    }.collect()
    +    val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme
    +    evalOneSampleP(ksStat, n.toLong)
    +  }
    +
    +  /**
    +   * Runs a KS test for 1 set of sample data, comparing it to a theoretical distribution
    +   * @param data `RDD[Double]` data on which to run test
    +   * @param distObj `RealDistribution` a theoretical distribution
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the test
    +   *        results (p-value, statistic, and null hypothesis)
    +   */
    +  def testOneSample(data: RDD[Double], distObj: RealDistribution): KolmogorovSmirnovTestResult = {
    +    val cdf = (x: Double) => distObj.cumulativeProbability(x)
    +    testOneSample(data, cdf)
    +  }
    +
    +  /**
    +   * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a
    +   * partition
    +   * @param partData `Iterator[Double]` 1 partition of a sorted RDD
    +   * @param n `Double` the total size of the RDD
     +   * @param cdf `Double => Double` a function that calculates the theoretical CDF of a value
     +   * @return `Iterator[(Double, Double)]` unadjusted (i.e. off by a constant) potential extrema
     +   *        in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF,
     +   *        the second element corresponds to empirical CDF - CDF. We can then search the resulting
     +   *        iterator for the minimum of the first and the maximum of the second element, and provide
     +   *        these as a partition's candidate extrema
    +   */
    +  private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double)
    +    : Iterator[(Double, Double)] = {
    +    // zip data with index (within that partition)
    +    // calculate local (unadjusted) empirical CDF and subtract CDF
    +    partData.zipWithIndex.map { case (v, ix) =>
    +      // dp and dl are later adjusted by constant, when global info is available
    +      val dp = (ix + 1) / n
    +      val dl = ix / n
    +      val cdfVal = cdf(v)
    +      (dl - cdfVal, dp - cdfVal)
    +    }
    +  }
    +
    +  /**
    +   * Search the unadjusted differences in a partition and return the
    +   * two extrema (furthest below and furthest above CDF), along with a count of elements in that
    +   * partition
    +   * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF
     +   *                 and CDF in a partition, which come as a tuple of
    +   *                 (empirical CDF - 1/N - CDF, empirical CDF - CDF)
    +   * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements
    +   */
    +  private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)])
    +    : Iterator[(Double, Double, Double)] = {
    +    val initAcc = (Double.MaxValue, Double.MinValue, 0.0)
    +    val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) =>
    +      (math.min(pMin, dl), math.max(pMax, dp), pCt + 1)
    +    }
    +    val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults)
    +    results.iterator
    +  }
    +
    +  /**
    +   * Find the global maximum distance between empirical CDF and CDF (i.e. the KS statistic) after
     +   * adjusting local extrema estimates from individual partitions with the number of elements in
    +   * preceding partitions
    +   * @param localData `Array[(Double, Double, Double)]` A local array containing the collected
    +   *                 results of `searchOneSampleCandidates` across all partitions
     +   * @param n `Double` the size of the RDD
     +   * @return The one-sample Kolmogorov-Smirnov statistic
    +   */
    +  private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double)
    +    : Double = {
    +    val initAcc = (Double.MinValue, 0.0)
    +    // adjust differences based on the number of elements preceding it, which should provide
    +    // the correct distance between empirical CDF and CDF
    +    val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) =>
    +      val adjConst = prevCt / n
    +      val dist1 = math.abs(minCand + adjConst)
    +      val dist2 = math.abs(maxCand + adjConst)
    +      val maxVal = Array(prevMax, dist1, dist2).max
    +      (maxVal, prevCt + ct)
    +    }
    +    results._1
    +  }
    +
    +  /**
    +   * A convenience function that allows running the KS test for 1 set of sample data against
    +   * a named distribution
    +   * @param data the sample data that we wish to evaluate
    +   * @param distName the name of the theoretical distribution
    +   * @param params Variable length parameter for distribution's parameters
    +   * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] summarizing the
    +   *        test results (p-value, statistic, and null hypothesis)
    +   */
    +  @varargs
    +  def testOneSample(data: RDD[Double], distName: String, params: Double*)
    +    : KolmogorovSmirnovTestResult = {
    +    val distObj =
    +      distName match {
    +        case "norm" => {
    +          if (params.nonEmpty) {
     +            // if parameters are passed, there must be exactly 2
    +            require(params.length == 2, "Normal distribution requires mean and standard " +
    +              "deviation as parameters")
    +            new NormalDistribution(params(0), params(1))
    +          } else {
     +            // if no parameters are passed in, initialize to the standard normal
     +            logInfo("No parameters specified for normal distribution, " +
     +              "initialized to standard normal (i.e. N(0, 1))")
    +            new NormalDistribution(0, 1)
    +          }
    +        }
     +        case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" +
     +          s" convenience method. Current options are: ['norm'].")
    +      }
    +
    +    testOneSample(data, distObj)
    +  }
    +
    +  private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = {
    +    val pval = 1 - new KolmogorovSmirnovTest().cdf(ksStat, n.toInt)
    +    new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString)
    +  }
    +}
    +
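
To make the partition-level bookkeeping of the new test concrete, here is a small plain-Scala sketch (no Spark; partitions are simulated with chunks of 250, the object name is ours, and the theoretical distribution is Uniform(0, 1), whose CDF is the identity) of the same candidate-and-adjust scheme the implementation note describes:

```scala
object KSStatisticSketch {
  def main(args: Array[String]): Unit = {
    val cdf: Double => Double = (x: Double) => x  // CDF of Uniform(0, 1)
    val data = Array.fill(1000)(scala.util.Random.nextDouble()).sorted
    val n = data.length.toDouble

    // Per-"partition": fold the local, unadjusted differences into
    // (min candidate, max candidate, element count), as searchOneSampleCandidates does.
    val candidates = data.grouped(250).map { part =>
      part.zipWithIndex.foldLeft((Double.MaxValue, Double.MinValue, 0.0)) {
        case ((pMin, pMax, pCt), (v, ix)) =>
          val dl = ix / n - cdf(v)        // locally indexed (empirical CDF - 1/n) - CDF
          val dp = (ix + 1) / n - cdf(v)  // locally indexed empirical CDF - CDF
          (math.min(pMin, dl), math.max(pMax, dp), pCt + 1)
      }
    }.toArray

    // Global pass: shift each chunk's candidates by the count of preceding elements, keep the max.
    val ksStat = candidates.foldLeft((Double.MinValue, 0.0)) {
      case ((prevMax, prevCt), (minCand, maxCand, ct)) =>
        val adj = prevCt / n
        (Seq(prevMax, math.abs(minCand + adj), math.abs(maxCand + adj)).max, prevCt + ct)
    }._1

    println(s"KS statistic for a Uniform(0, 1) sample of size ${n.toInt}: $ksStat")
  }
}
```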
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
    index 4784f9e947908..f44be13706695 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
    @@ -90,3 +90,20 @@ class ChiSqTestResult private[stat] (override val pValue: Double,
           super.toString
       }
     }
    +
    +/**
    + * :: Experimental ::
    + * Object containing the test results for the Kolmogorov-Smirnov test.
    + */
    +@Experimental
    +class KolmogorovSmirnovTestResult private[stat] (
    +    override val pValue: Double,
    +    override val statistic: Double,
    +    override val nullHypothesis: String) extends TestResult[Int] {
    +
    +  override val degreesOfFreedom = 0
    +
    +  override def toString: String = {
    +    "Kolmogorov-Smirnov test summary:\n" + super.toString
    +  }
    +}
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
    index dfe3a0b6913ef..cecd1fed896d5 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
    @@ -169,7 +169,7 @@ object DecisionTree extends Serializable with Logging {
           numClasses: Int,
           maxBins: Int,
           quantileCalculationStrategy: QuantileStrategy,
    -      categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
    +      categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
         val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
           quantileCalculationStrategy, categoricalFeaturesInfo)
         new DecisionTree(strategy).run(input)
    @@ -768,7 +768,7 @@ object DecisionTree extends Serializable with Logging {
        */
       private def calculatePredictImpurity(
           leftImpurityCalculator: ImpurityCalculator,
    -      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) =  {
    +      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
         val parentNodeAgg = leftImpurityCalculator.copy
         parentNodeAgg.add(rightImpurityCalculator)
         val predict = calculatePredict(parentNodeAgg)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
    index 1f779584dcffd..a835f96d5d0e3 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
    @@ -60,12 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
       def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
         val algo = boostingStrategy.treeStrategy.algo
         algo match {
    -      case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
    +      case Regression =>
    +        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
           case Classification =>
             // Map labels to -1, +1 so binary classification can be treated as regression.
             val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
    -        GradientBoostedTrees.boost(remappedInput,
    -          remappedInput, boostingStrategy, validate=false)
    +        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
           case _ =>
             throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
         }
    @@ -93,8 +93,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
           validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
         val algo = boostingStrategy.treeStrategy.algo
         algo match {
    -      case Regression => GradientBoostedTrees.boost(
    -        input, validationInput, boostingStrategy, validate=true)
    +      case Regression =>
    +        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
           case Classification =>
             // Map labels to -1, +1 so binary classification can be treated as regression.
             val remappedInput = input.map(
    @@ -102,7 +102,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
             val remappedValidationInput = validationInput.map(
               x => new LabeledPoint((x.label * 2) - 1, x.features))
             GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
    -          validate=true)
    +          validate = true)
           case _ =>
             throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
         }
    @@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging {
         logInfo(s"$timer")
     
         if (persistedInput) input.unpersist()
    -    
    +
         if (validate) {
           new GradientBoostedTreesModel(
             boostingStrategy.treeStrategy.algo,
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
    index 055e60c7d9c95..069959976a188 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
    @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._
     import org.apache.spark.rdd.RDD
     import org.apache.spark.storage.StorageLevel
     import org.apache.spark.util.Utils
    +import org.apache.spark.util.random.SamplingUtils
     
     /**
      * :: Experimental ::
    @@ -248,7 +249,7 @@ private class RandomForest (
           try {
             nodeIdCache.get.deleteAllCheckpoints()
           } catch {
    -        case e:IOException =>
    +        case e: IOException =>
               logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
           }
         }
    @@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging {
           val (treeIndex, node) = nodeQueue.head
           // Choose subset of features for node (if subsampling).
           val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
    -        // TODO: Use more efficient subsampling?  (use selection-and-rejection or reservoir)
    -        Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
    -          .take(metadata.numFeaturesPerNode).toArray)
    +        Some(SamplingUtils.reservoirSampleAndCount(Range(0,
    +          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
           } else {
             None
           }
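
`SamplingUtils.reservoirSampleAndCount` is a Spark-internal utility; as a rough illustration of the idea it brings to the hunk above (drawing a fixed-size feature subset in a single pass instead of shuffling the whole index range), a generic Algorithm R sketch might look like this — the object and method names here are ours, not Spark's:

```scala
import scala.util.Random

object ReservoirSketch {
  /** Sample k items uniformly from an iterator of unknown length (Algorithm R). */
  def reservoirSample[T: scala.reflect.ClassTag](input: Iterator[T], k: Int, seed: Long): Array[T] = {
    val rng = new Random(seed)
    val reservoir = new Array[T](k)
    var i = 0L
    input.foreach { item =>
      if (i < k) {
        reservoir(i.toInt) = item                    // fill the reservoir first
      } else {
        val j = (rng.nextDouble() * (i + 1)).toLong  // uniform index in [0, i]
        if (j < k) reservoir(j.toInt) = item         // replace with decreasing probability
      }
      i += 1
    }
    if (i < k) reservoir.take(i.toInt) else reservoir
  }

  def main(args: Array[String]): Unit = {
    // e.g. choose 5 of 100 feature indices without materializing and shuffling the full range
    println(reservoirSample(Range(0, 100).iterator, 5, seed = 42L).mkString(", "))
  }
}
```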
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    index 60e2ab2bb829e..72eb24c49264a 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    @@ -111,11 +111,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
        * Add the stats from another calculator into this one, modifying and returning this calculator.
        */
       def add(other: ImpurityCalculator): ImpurityCalculator = {
    -    require(stats.size == other.stats.size,
    +    require(stats.length == other.stats.length,
           s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
    -        s"  Sizes are ${stats.size} and ${other.stats.size}.")
    +        s"  Sizes are ${stats.length} and ${other.stats.length}.")
         var i = 0
    -    while (i < other.stats.size) {
    +    val len = other.stats.length
    +    while (i < len) {
           stats(i) += other.stats(i)
           i += 1
         }
    @@ -127,11 +128,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
        * calculator.
        */
       def subtract(other: ImpurityCalculator): ImpurityCalculator = {
    -    require(stats.size == other.stats.size,
    +    require(stats.length == other.stats.length,
           s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
    -      s"  Sizes are ${stats.size} and ${other.stats.size}.")
    +      s"  Sizes are ${stats.length} and ${other.stats.length}.")
         var i = 0
    -    while (i < other.stats.size) {
    +    val len = other.stats.length
    +    while (i < len) {
           stats(i) -= other.stats(i)
           i += 1
         }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
    index 331af428533de..f2c78bbabff0b 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
    @@ -198,7 +198,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
             val driverMemory = sc.getConf.getOption("spark.driver.memory")
               .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
               .map(Utils.memoryStringToMb)
    -          .getOrElse(512)
    +          .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB)
             if (driverMemory <= memThreshold) {
               logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
                 s" driver memory (${driverMemory}m)." +
    @@ -223,14 +223,14 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
           val dataRDD: DataFrame = sc.parallelize(nodes)
             .map(NodeData.apply(0, _))
             .toDF()
    -      dataRDD.saveAsParquetFile(Loader.dataPath(path))
    +      dataRDD.write.parquet(Loader.dataPath(path))
         }
     
         def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
           val datapath = Loader.dataPath(path)
           val sqlContext = new SQLContext(sc)
           // Load Parquet data.
    -      val dataRDD = sqlContext.parquetFile(datapath)
    +      val dataRDD = sqlContext.read.parquet(datapath)
           // Check schema explicitly since erasure makes it hard to use match-case for checking.
           Loader.checkSchema[NodeData](dataRDD.schema)
           val nodes = dataRDD.map(NodeData.apply)
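
The save/load hunks here (and in the tree-ensemble loader further down) move from the deprecated `saveAsParquetFile`/`parquetFile` calls to the `DataFrame` reader/writer API. A minimal, self-contained usage sketch of that API follows; the `NodeRow` case class and the `/tmp/node-data` path are illustrative only:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical row type, only for this sketch.
case class NodeRow(treeId: Int, nodeId: Int, prediction: Double)

object ParquetIOSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("ParquetIOSketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(NodeRow(0, 1, 0.5), NodeRow(0, 2, 1.0))).toDF()
    // New writer API (replaces the deprecated df.saveAsParquetFile(path)).
    df.write.mode("overwrite").parquet("/tmp/node-data")
    // New reader API (replaces the deprecated sqlContext.parquetFile(path)).
    val loaded = sqlContext.read.parquet("/tmp/node-data")
    loaded.show()
    sc.stop()
  }
}
```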
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
    index 431a839817eac..a6d1398fc267b 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
    @@ -83,7 +83,7 @@ class Node (
       def predict(features: Vector) : Double = {
         if (isLeaf) {
           predict.predict
    -    } else{
    +    } else {
           if (split.get.featureType == Continuous) {
             if (features(split.get.feature) <= split.get.threshold) {
               leftNode.get.predict(features)
    @@ -151,9 +151,9 @@ class Node (
               s"(feature ${split.feature} > ${split.threshold})"
             }
             case Categorical => if (left) {
    -          s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})"
    +          s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})"
             } else {
    -          s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})"
    +          s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})"
             }
           }
         }
    @@ -161,9 +161,9 @@ class Node (
         if (isLeaf) {
           prefix + s"Predict: ${predict.predict}\n"
         } else {
    -      prefix + s"If ${splitToString(split.get, left=true)}\n" +
    +      prefix + s"If ${splitToString(split.get, left = true)}\n" +
             leftNode.get.subtreeToString(indentFactor + 1) +
    -        prefix + s"Else ${splitToString(split.get, left=false)}\n" +
    +        prefix + s"Else ${splitToString(split.get, left = false)}\n" +
             rightNode.get.subtreeToString(indentFactor + 1)
         }
       }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
    index 8341219bfa71c..905c5fb42bd44 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
    @@ -387,7 +387,7 @@ private[tree] object TreeEnsembleModel extends Logging {
             val driverMemory = sc.getConf.getOption("spark.driver.memory")
               .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
               .map(Utils.memoryStringToMb)
    -          .getOrElse(512)
    +          .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB)
             if (driverMemory <= memThreshold) {
               logWarning(s"$className.save() was called, but it may fail because of too little" +
                 s" driver memory (${driverMemory}m)." +
    @@ -414,7 +414,7 @@ private[tree] object TreeEnsembleModel extends Logging {
           val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
             tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
           }.toDF()
    -      dataRDD.saveAsParquetFile(Loader.dataPath(path))
    +      dataRDD.write.parquet(Loader.dataPath(path))
         }
     
         /**
    @@ -437,7 +437,7 @@ private[tree] object TreeEnsembleModel extends Logging {
             treeAlgo: String): Array[DecisionTreeModel] = {
           val datapath = Loader.dataPath(path)
           val sqlContext = new SQLContext(sc)
    -      val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply)
    +      val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply)
           val trees = constructTrees(nodes)
           trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
         }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
    index 6eaebaf7dba9f..e6bcff48b022c 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
    @@ -64,8 +64,10 @@ object KMeansDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length < 6) {
    +      // scalastyle:off println
           println("Usage: KMeansGenerator " +
             "      []")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    index b1a4517344970..87eeb5db05d26 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    @@ -107,7 +107,8 @@ object LinearDataGenerator {
     
         x.foreach { v =>
           var i = 0
    -      while (i < v.length) {
    +      val len = v.length
    +      while (i < len) {
             v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
             i += 1
           }
    @@ -152,8 +153,10 @@ object LinearDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length < 2) {
    +      // scalastyle:off println
           println("Usage: LinearDataGenerator " +
             "  [num_examples] [num_features] [num_partitions]")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
    index 9d802678c4a77..c09cbe69bb971 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
    @@ -64,8 +64,10 @@ object LogisticRegressionDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length != 5) {
    +      // scalastyle:off println
           println("Usage: LogisticRegressionGenerator " +
             "    ")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
    index 0c5b4f9d04a74..16f430599a515 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
    @@ -55,8 +55,10 @@ import org.apache.spark.rdd.RDD
     object MFDataGenerator {
       def main(args: Array[String]) {
         if (args.length < 2) {
    +      // scalastyle:off println
           println("Usage: MFDataGenerator " +
             "  [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    @@ -82,8 +84,7 @@ object MFDataGenerator {
         BLAS.gemm(z, A, B, 1.0, fullData)
     
         val df = rank * (m + n - rank)
    -    val sampSize = scala.math.min(scala.math.round(trainSampFact * df),
    -      scala.math.round(.99 * m * n)).toInt
    +    val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt
         val rand = new Random()
         val mn = m * n
         val shuffled = rand.shuffle((0 until mn).toList)
    @@ -102,8 +103,8 @@ object MFDataGenerator {
     
         // optionally generate testing data
         if (test) {
    -      val testSampSize = scala.math
    -        .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt
    +      val testSampSize = math.min(
    +        math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt
           val testOmega = shuffled.slice(sampSize, sampSize + testSampSize)
           val testOrdered = testOmega.sortWith(_ < _).toArray
           val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
    index 681f4c618d302..7c5cfa7bd84ce 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
    @@ -82,6 +82,18 @@ object MLUtils {
               val value = indexAndValue(1).toDouble
               (index, value)
             }.unzip
    +
    +        // check if indices are one-based and in ascending order
    +        var previous = -1
    +        var i = 0
    +        val indicesLength = indices.length
    +        while (i < indicesLength) {
    +          val current = indices(i)
     +          require(current > previous, "indices should be one-based and in ascending order")
    +          previous = current
    +          i += 1
    +        }
    +
             (label, indices.toArray, values.toArray)
           }
     
    @@ -258,14 +270,30 @@ object MLUtils {
        * Returns a new vector with `1.0` (bias) appended to the input vector.
        */
       def appendBias(vector: Vector): Vector = {
    -    val vector1 = vector.toBreeze match {
    -      case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(1.0)))
    -      case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(1.0), 1))
    -      case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
    +    vector match {
    +      case dv: DenseVector =>
    +        val inputValues = dv.values
    +        val inputLength = inputValues.length
    +        val outputValues = Array.ofDim[Double](inputLength + 1)
    +        System.arraycopy(inputValues, 0, outputValues, 0, inputLength)
    +        outputValues(inputLength) = 1.0
    +        Vectors.dense(outputValues)
    +      case sv: SparseVector =>
    +        val inputValues = sv.values
    +        val inputIndices = sv.indices
    +        val inputValuesLength = inputValues.length
    +        val dim = sv.size
    +        val outputValues = Array.ofDim[Double](inputValuesLength + 1)
    +        val outputIndices = Array.ofDim[Int](inputValuesLength + 1)
    +        System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength)
    +        System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength)
    +        outputValues(inputValuesLength) = 1.0
    +        outputIndices(inputValuesLength) = dim
    +        Vectors.sparse(dim + 1, outputIndices, outputValues)
    +      case _ => throw new IllegalArgumentException(s"Do not support vector type ${vector.getClass}")
         }
    -    Vectors.fromBreeze(vector1)
       }
    - 
    +
       /**
        * Returns the squared Euclidean distance between two vectors. The following formula will be used
        * if it does not introduce too much numerical error:
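
The `appendBias` rewrite above builds the result directly on the MLlib vector types instead of round-tripping through Breeze. A quick behavioral check, assuming `MLUtils.appendBias` remains public as shown here (the expected outputs in the comments are our reading of the new code, not asserted by this patch):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils

object AppendBiasSketch {
  def main(args: Array[String]): Unit = {
    val dense = Vectors.dense(1.0, 2.0)
    val sparse = Vectors.sparse(3, Array(1), Array(4.0))

    // Expected: [1.0,2.0,1.0] -- the bias lands in a new trailing slot.
    println(MLUtils.appendBias(dense))
    // Expected: a size-4 sparse vector keeping the existing entry plus 1.0 at the new index 3.
    println(MLUtils.appendBias(sparse))
  }
}
```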
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
    index 308f7f3578e21..a841c5caf0142 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
    @@ -98,6 +98,8 @@ private[mllib] object NumericParser {
             }
           } else if (token == ")") {
             parsing = false
     +      } else if (token.trim.isEmpty) {
     +        // ignore whitespace between delimiter chars, e.g. ", ["
           } else {
             // expecting a number
             items.append(parseDouble(token))
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
    index a8e30cc9d730c..ad20b7694a779 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
    @@ -37,8 +37,10 @@ object SVMDataGenerator {
     
       def main(args: Array[String]) {
         if (args.length < 2) {
    +      // scalastyle:off println
           println("Usage: SVMGenerator " +
             "  [num_examples] [num_features] [num_partitions]")
    +      // scalastyle:on println
           System.exit(1)
         }
     
    diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
    index 7e7189a2b1d53..f75e024a713ee 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
    @@ -84,7 +84,7 @@ public void logisticRegressionWithSetters() {
           .setThreshold(0.6)
           .setProbabilityCol("myProbability");
         LogisticRegressionModel model = lr.fit(dataset);
    -    LogisticRegression parent = model.parent();
    +    LogisticRegression parent = (LogisticRegression) model.parent();
         assert(parent.getMaxIter() == 10);
         assert(parent.getRegParam() == 1.0);
         assert(parent.getThreshold() == 0.6);
    @@ -110,7 +110,7 @@ public void logisticRegressionWithSetters() {
         // Call fit() with new params, and check as many params as we can.
         LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
             lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
    -    LogisticRegression parent2 = model2.parent();
    +    LogisticRegression parent2 = (LogisticRegression) model2.parent();
         assert(parent2.getMaxIter() == 5);
         assert(parent2.getRegParam() == 0.1);
         assert(parent2.getThreshold() == 0.4);
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
    new file mode 100644
    index 0000000000000..d5bd230a957a1
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
    @@ -0,0 +1,80 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.DataTypes;
    +import org.apache.spark.sql.types.Metadata;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
    +
    +public class JavaBucketizerSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext jsql;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
    +    jsql = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  @Test
    +  public void bucketizerTest() {
    +    double[] splits = {-0.5, 0.0, 0.5};
    +
     +    JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
    +      RowFactory.create(-0.5),
    +      RowFactory.create(-0.3),
    +      RowFactory.create(0.0),
    +      RowFactory.create(0.2)
    +    ));
    +    StructType schema = new StructType(new StructField[] {
    +      new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
    +    });
    +    DataFrame dataset = jsql.createDataFrame(data, schema);
    +
    +    Bucketizer bucketizer = new Bucketizer()
    +      .setInputCol("feature")
    +      .setOutputCol("result")
    +      .setSplits(splits);
    +
    +    Row[] result = bucketizer.transform(dataset).select("result").collect();
    +
    +    for (Row r : result) {
    +      double index = r.getDouble(0);
    +      Assert.assertTrue((index >= 0) && (index <= 1));
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
    new file mode 100644
    index 0000000000000..845eed61c45c6
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
    @@ -0,0 +1,78 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import com.google.common.collect.Lists;
    +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.mllib.linalg.VectorUDT;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.Metadata;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
    +
    +public class JavaDCTSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext jsql;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaDCTSuite");
    +    jsql = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  @Test
    +  public void javaCompatibilityTest() {
    +    double[] input = new double[] {1D, 2D, 3D, 4D};
     +    JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
    +      RowFactory.create(Vectors.dense(input))
    +    ));
    +    DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{
    +      new StructField("vec", (new VectorUDT()), false, Metadata.empty())
    +    }));
    +
    +    double[] expectedResult = input.clone();
    +    (new DoubleDCT_1D(input.length)).forward(expectedResult, true);
    +
    +    DCT dct = new DCT()
    +      .setInputCol("vec")
    +      .setOutputCol("resultVec");
    +
    +    Row[] result = dct.transform(dataset).select("resultVec").collect();
    +    Vector resultVec = result[0].getAs("resultVec");
    +
    +    Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
    index 23463ab5fe848..599e9cfd23ad4 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
    @@ -55,25 +55,30 @@ public void tearDown() {
       @Test
       public void hashingTF() {
          JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
    -      RowFactory.create(0, "Hi I heard about Spark"),
    -      RowFactory.create(0, "I wish Java could use case classes"),
    -      RowFactory.create(1, "Logistic regression models are neat")
    +      RowFactory.create(0.0, "Hi I heard about Spark"),
    +      RowFactory.create(0.0, "I wish Java could use case classes"),
    +      RowFactory.create(1.0, "Logistic regression models are neat")
         ));
         StructType schema = new StructType(new StructField[]{
           new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
           new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
         });
    -    DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema);
     
    -    Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
    -    DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
    +    DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
    +    Tokenizer tokenizer = new Tokenizer()
    +      .setInputCol("sentence")
    +      .setOutputCol("words");
    +    DataFrame wordsData = tokenizer.transform(sentenceData);
         int numFeatures = 20;
         HashingTF hashingTF = new HashingTF()
           .setInputCol("words")
    -      .setOutputCol("features")
    +      .setOutputCol("rawFeatures")
           .setNumFeatures(numFeatures);
    -    DataFrame featurized = hashingTF.transform(wordsDataFrame);
    -    for (Row r : featurized.select("features", "words", "label").take(3)) {
    +    DataFrame featurizedData = hashingTF.transform(wordsData);
    +    IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
    +    IDFModel idfModel = idf.fit(featurizedData);
    +    DataFrame rescaledData = idfModel.transform(featurizedData);
    +    for (Row r : rescaledData.select("features", "label").take(3)) {
           Vector features = r.getAs(0);
           Assert.assertEquals(features.size(), numFeatures);
         }
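
For reference, the same tokenize → hash → IDF flow that the updated Java test now exercises looks roughly like this on the Scala side (the app name, sentences, and column names are illustrative, not taken from this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.sql.SQLContext

object TfIdfSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("TfIdfSketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val sentenceData = sc.parallelize(Seq(
      (0.0, "Hi I heard about Spark"),
      (1.0, "Logistic regression models are neat")
    )).toDF("label", "sentence")

    // Tokenize, then hash the words into fixed-size term-frequency vectors.
    val wordsData = new Tokenizer().setInputCol("sentence").setOutputCol("words")
      .transform(sentenceData)
    val featurizedData = new HashingTF().setInputCol("words").setOutputCol("rawFeatures")
      .setNumFeatures(20).transform(wordsData)
    // IDF is an estimator: fit on the term-frequency vectors, then rescale them.
    val rescaled = new IDF().setInputCol("rawFeatures").setOutputCol("features")
      .fit(featurizedData).transform(featurizedData)

    rescaled.select("features", "label").show()
    sc.stop()
  }
}
```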
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
    new file mode 100644
    index 0000000000000..d82f3b7e8c076
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
    @@ -0,0 +1,71 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import java.util.List;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.SQLContext;
    +
    +public class JavaNormalizerSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext jsql;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
    +    jsql = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  @Test
    +  public void normalizer() {
    +    // The tests are to check Java compatibility.
     +    List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
    +      new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
    +      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
    +      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
    +    );
    +    DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
    +      VectorIndexerSuite.FeatureData.class);
    +    Normalizer normalizer = new Normalizer()
    +      .setInputCol("features")
    +      .setOutputCol("normFeatures");
    +
    +    // Normalize each Vector using $L^2$ norm.
    +    DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
    +    l2NormData.count();
    +
    +    // Normalize each Vector using $L^\infty$ norm.
    +    DataFrame lInfNormData =
    +      normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
    +    lInfNormData.count();
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
    new file mode 100644
    index 0000000000000..5cf43fec6f29e
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
    @@ -0,0 +1,114 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import java.io.Serializable;
    +import java.util.List;
    +
    +import scala.Tuple2;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.function.Function;
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.distributed.RowMatrix;
    +import org.apache.spark.mllib.linalg.Matrix;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.SQLContext;
    +
    +public class JavaPCASuite implements Serializable {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext sqlContext;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaPCASuite");
    +    sqlContext = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  public static class VectorPair implements Serializable {
    +    private Vector features = Vectors.dense(0.0);
    +    private Vector expected = Vectors.dense(0.0);
    +
    +    public void setFeatures(Vector features) {
    +      this.features = features;
    +    }
    +
    +    public Vector getFeatures() {
    +      return this.features;
    +    }
    +
    +    public void setExpected(Vector expected) {
    +      this.expected = expected;
    +    }
    +
    +    public Vector getExpected() {
    +      return this.expected;
    +    }
    +  }
    +
    +  @Test
    +  public void testPCA() {
     +    List<Vector> points = Lists.newArrayList(
    +      Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}),
    +      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
    +      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
    +    );
     +    JavaRDD<Vector> dataRDD = jsc.parallelize(points, 2);
    +
    +    RowMatrix mat = new RowMatrix(dataRDD.rdd());
    +    Matrix pc = mat.computePrincipalComponents(3);
     +    JavaRDD<Vector> expected = mat.multiply(pc).rows().toJavaRDD();
    +
     +    JavaRDD<VectorPair> featuresExpected = dataRDD.zip(expected).map(
     +      new Function<Tuple2<Vector, Vector>, VectorPair>() {
     +        public VectorPair call(Tuple2<Vector, Vector> pair) {
    +          VectorPair featuresExpected = new VectorPair();
    +          featuresExpected.setFeatures(pair._1());
    +          featuresExpected.setExpected(pair._2());
    +          return featuresExpected;
    +        }
    +      }
    +    );
    +
    +    DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
    +    PCAModel pca = new PCA()
    +      .setInputCol("features")
    +      .setOutputCol("pca_features")
    +      .setK(3)
    +      .fit(df);
     +    List<Row> result = pca.transform(df).select("pca_features", "expected").toJavaRDD().collect();
    +    for (Row r : result) {
    +      Assert.assertEquals(r.get(1), r.get(0));
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
    new file mode 100644
    index 0000000000000..5e8211c2c5118
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
    @@ -0,0 +1,91 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.mllib.linalg.VectorUDT;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.Metadata;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
    +
    +public class JavaPolynomialExpansionSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext jsql;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite");
    +    jsql = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  @Test
    +  public void polynomialExpansionTest() {
    +    PolynomialExpansion polyExpansion = new PolynomialExpansion()
    +      .setInputCol("features")
    +      .setOutputCol("polyFeatures")
    +      .setDegree(3);
    +
     +    JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
    +      RowFactory.create(
    +        Vectors.dense(-2.0, 2.3),
    +        Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
    +      ),
    +      RowFactory.create(Vectors.dense(0.0, 0.0), Vectors.dense(new double[9])),
    +      RowFactory.create(
    +        Vectors.dense(0.6, -1.1),
    +        Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331)
    +      )
    +    ));
    +
    +    StructType schema = new StructType(new StructField[] {
    +      new StructField("features", new VectorUDT(), false, Metadata.empty()),
    +      new StructField("expected", new VectorUDT(), false, Metadata.empty())
    +    });
    +
    +    DataFrame dataset = jsql.createDataFrame(data, schema);
    +
    +    Row[] pairs = polyExpansion.transform(dataset)
    +      .select("polyFeatures", "expected")
    +      .collect();
    +
    +    for (Row r : pairs) {
    +      double[] polyFeatures = ((Vector)r.get(0)).toArray();
    +      double[] expected = ((Vector)r.get(1)).toArray();
    +      Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
    new file mode 100644
    index 0000000000000..74eb2733f06ef
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
    @@ -0,0 +1,71 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import java.util.List;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.SQLContext;
    +
    +public class JavaStandardScalerSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext jsql;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
    +    jsql = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  @Test
    +  public void standardScaler() {
    +    // The tests are to check Java compatibility.
+    List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
    +      new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
    +      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
    +      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
    +    );
    +    DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
    +      VectorIndexerSuite.FeatureData.class);
    +    StandardScaler scaler = new StandardScaler()
    +      .setInputCol("features")
    +      .setOutputCol("scaledFeatures")
    +      .setWithStd(true)
    +      .setWithMean(false);
    +
    +    // Compute summary statistics by fitting the StandardScaler
    +    StandardScalerModel scalerModel = scaler.fit(dataFrame);
    +
    +    // Normalize each feature to have unit standard deviation.
    +    DataFrame scaledData = scalerModel.transform(dataFrame);
    +    scaledData.count();
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
    new file mode 100644
    index 0000000000000..35b18c5308f61
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
    @@ -0,0 +1,77 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import java.util.Arrays;
    +
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
    +import static org.apache.spark.sql.types.DataTypes.*;
    +
    +public class JavaStringIndexerSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext sqlContext;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
    +    sqlContext = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    sqlContext = null;
    +  }
    +
    +  @Test
    +  public void testStringIndexer() {
    +    StructType schema = createStructType(new StructField[] {
    +      createStructField("id", IntegerType, false),
    +      createStructField("label", StringType, false)
    +    });
+    JavaRDD<Row> rdd = jsc.parallelize(
    +      Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
    +    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
    +
    +    StringIndexer indexer = new StringIndexer()
    +      .setInputCol("label")
    +      .setOutputCol("labelIndex");
    +    DataFrame output = indexer.fit(dataset).transform(dataset);
    +
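+    // "a" is the most frequent label, so it maps to index 0.0, followed by "c" (1.0) and "b" (2.0).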
    +    Assert.assertArrayEquals(
    +      new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) },
    +      output.orderBy("id").select("id", "labelIndex").collect());
    +  }
    +
    +  /** An alias for RowFactory.create. */
    +  private Row c(Object... values) {
    +    return RowFactory.create(values);
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
    new file mode 100644
    index 0000000000000..b7c564caad3bd
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
    @@ -0,0 +1,78 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import java.util.Arrays;
    +
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.mllib.linalg.VectorUDT;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.*;
    +import static org.apache.spark.sql.types.DataTypes.*;
    +
    +public class JavaVectorAssemblerSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext sqlContext;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
    +    sqlContext = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  @Test
    +  public void testVectorAssembler() {
    +    StructType schema = createStructType(new StructField[] {
    +      createStructField("id", IntegerType, false),
    +      createStructField("x", DoubleType, false),
    +      createStructField("y", new VectorUDT(), false),
    +      createStructField("name", StringType, false),
    +      createStructField("z", new VectorUDT(), false),
    +      createStructField("n", LongType, false)
    +    });
    +    Row row = RowFactory.create(
    +      0, 0.0, Vectors.dense(1.0, 2.0), "a",
    +      Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
+    JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
    +    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
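+    // Assembling x, y, z and n (in that order) should produce a single sparse vector of size 6.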
    +    VectorAssembler assembler = new VectorAssembler()
    +      .setInputCols(new String[] {"x", "y", "z", "n"})
    +      .setOutputCol("features");
    +    DataFrame output = assembler.transform(dataset);
    +    Assert.assertEquals(
    +      Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
    +      output.select("features").first().getAs(0));
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
    index 161100134c92d..c7ae5468b9429 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
    @@ -19,6 +19,7 @@
     
     import java.io.Serializable;
     import java.util.List;
    +import java.util.Map;
     
     import org.junit.After;
     import org.junit.Assert;
    @@ -64,7 +65,8 @@ public void vectorIndexerAPI() {
           .setMaxCategories(2);
         VectorIndexerModel model = indexer.fit(data);
         Assert.assertEquals(model.numFeatures(), 2);
    -    Assert.assertEquals(model.categoryMaps().size(), 1);
+    Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps();
    +    Assert.assertEquals(categoryMaps.size(), 1);
         DataFrame indexedData = model.transform(data);
       }
     }
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
    new file mode 100644
    index 0000000000000..39c70157f83c0
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
    @@ -0,0 +1,76 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.*;
    +
    +public class JavaWord2VecSuite {
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext sqlContext;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaWord2VecSuite");
    +    sqlContext = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  @Test
    +  public void testJavaWord2Vec() {
+    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
    +      RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))),
    +      RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))),
    +      RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" ")))
    +    ));
    +    StructType schema = new StructType(new StructField[]{
    +      new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
    +    });
    +    DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
    +
    +    Word2Vec word2Vec = new Word2Vec()
    +      .setInputCol("text")
    +      .setOutputCol("result")
    +      .setVectorSize(3)
    +      .setMinCount(0);
    +    Word2VecModel model = word2Vec.fit(documentDF);
    +    DataFrame result = model.transform(documentDF);
    +
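+    // Each document should be mapped to a vector of the configured size (3).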
    +    for (Row r: result.select("result").collect()) {
    +      double[] polyFeatures = ((Vector)r.get(0)).toArray();
    +      Assert.assertEquals(polyFeatures.length, 3);
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
    index e7df10dfa63ac..9890155e9f865 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
    @@ -50,6 +50,7 @@ public void testParams() {
         testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
         Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
         Assert.assertEquals(testParams.getMyStringParam(), "a");
    +    Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
       }
     
       @Test
    diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    index 8abe575610d19..3ae09d39ef500 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    @@ -21,43 +21,90 @@
     
     import com.google.common.collect.Lists;
     
    +import org.apache.spark.ml.util.Identifiable$;
    +
     /**
      * A subclass of Params for testing.
      */
     public class JavaTestParams extends JavaParams {
     
    -  public IntParam myIntParam;
    +  public JavaTestParams() {
    +    this.uid_ = Identifiable$.MODULE$.randomUID("javaTestParams");
    +    init();
    +  }
    +
    +  public JavaTestParams(String uid) {
    +    this.uid_ = uid;
    +    init();
    +  }
    +
    +  private String uid_;
    +
    +  @Override
    +  public String uid() {
    +    return uid_;
    +  }
    +
    +  private IntParam myIntParam_;
    +  public IntParam myIntParam() { return myIntParam_; }
     
    -  public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
    +  public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
     
       public JavaTestParams setMyIntParam(int value) {
    -    set(myIntParam, value); return this;
    +    set(myIntParam_, value);
    +    return this;
       }
     
    -  public DoubleParam myDoubleParam;
    +  private DoubleParam myDoubleParam_;
    +  public DoubleParam myDoubleParam() { return myDoubleParam_; }
     
    -  public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
    +  public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
     
       public JavaTestParams setMyDoubleParam(double value) {
    -    set(myDoubleParam, value); return this;
    +    set(myDoubleParam_, value);
    +    return this;
       }
     
-  public Param<String> myStringParam;
+  private Param<String> myStringParam_;
+  public Param<String> myStringParam() { return myStringParam_; }
     
    -  public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
    +  public String getMyStringParam() { return getOrDefault(myStringParam_); }
     
       public JavaTestParams setMyStringParam(String value) {
    -    set(myStringParam, value); return this;
    +    set(myStringParam_, value);
    +    return this;
       }
     
    -  public JavaTestParams() {
    -    myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
    -    myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
    +  private DoubleArrayParam myDoubleArrayParam_;
    +  public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
    +
    +  public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
    +
    +  public JavaTestParams setMyDoubleArrayParam(double[] value) {
    +    set(myDoubleArrayParam_, value);
    +    return this;
    +  }
    +
    +  private void init() {
    +    myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
    +    myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
           ParamValidators.inRange(0.0, 1.0));
     List<String> validStrings = Lists.newArrayList("a", "b");
-    myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
+    myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
           ParamValidators.inArray(validStrings));
    -    setDefault(myIntParam, 1);
    -    setDefault(myDoubleParam, 0.5);
    +    myDoubleArrayParam_ =
    +      new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
    +
    +    setDefault(myIntParam(), 1);
    +    setDefault(myIntParam().w(1));
    +    setDefault(myDoubleParam(), 0.5);
    +    setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
    +    setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
    +    setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
    +  }
    +
    +  @Override
    +  public JavaTestParams copy(ParamMap extra) {
    +    return defaultCopy(extra);
       }
     }
    diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
    index a82b86d560b6e..d591a456864e4 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
    @@ -77,14 +77,14 @@ public void linearRegressionWithSetters() {
             .setMaxIter(10)
             .setRegParam(1.0);
         LinearRegressionModel model = lr.fit(dataset);
    -    LinearRegression parent = model.parent();
    +    LinearRegression parent = (LinearRegression) model.parent();
         assertEquals(10, parent.getMaxIter());
         assertEquals(1.0, parent.getRegParam(), 0.0);
     
         // Call fit() with new params, and check as many params as we can.
         LinearRegressionModel model2 =
             lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
    -    LinearRegression parent2 = model2.parent();
    +    LinearRegression parent2 = (LinearRegression) model2.parent();
         assertEquals(5, parent2.getMaxIter());
         assertEquals(0.1, parent2.getRegParam(), 0.0);
         assertEquals("thePred", model2.getPredictionCol());
    diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
    new file mode 100644
    index 0000000000000..928301523fba9
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
    @@ -0,0 +1,40 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.util
    +
    +import org.apache.spark.SparkFunSuite
    +
    +class IdentifiableSuite extends SparkFunSuite {
    +
    +  import IdentifiableSuite.Test
    +
    +  test("Identifiable") {
    +    val test0 = new Test("test_0")
    +    assert(test0.uid === "test_0")
    +
    +    val test1 = new Test
    +    assert(test1.uid.startsWith("test_"))
    +  }
    +}
    +
    +object IdentifiableSuite {
    +
    +  class Test(override val uid: String) extends Identifiable {
    +    def this() = this(Identifiable.randomUID("test"))
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
    index 71fb7f13c39c2..3771c0ea7ad83 100644
    --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
    @@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception {
       @Test
       public void testModelTypeSetters() {
         NaiveBayes nb = new NaiveBayes()
    -        .setModelType("Bernoulli")
    -        .setModelType("Multinomial");
    +      .setModelType("bernoulli")
    +      .setModelType("multinomial");
       }
     }
    diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
    similarity index 95%
    rename from mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
    rename to mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
    index 640d2ec55e4e7..55787f8606d48 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
    @@ -15,7 +15,7 @@
      * limitations under the License.
      */
     
    -package org.apache.spark.ml.classification;
    +package org.apache.spark.mllib.classification;
     
     import java.io.Serializable;
     import java.util.List;
    @@ -28,7 +28,6 @@
     import org.junit.Test;
     
     import org.apache.spark.SparkConf;
    -import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
     import org.apache.spark.mllib.linalg.Vector;
     import org.apache.spark.mllib.linalg.Vectors;
     import org.apache.spark.mllib.regression.LabeledPoint;
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
    new file mode 100644
    index 0000000000000..467a7a69e8f30
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
    @@ -0,0 +1,64 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.clustering;
    +
    +import java.io.Serializable;
    +import java.util.List;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import static org.junit.Assert.assertEquals;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.mllib.linalg.Vectors;
    +
    +public class JavaGaussianMixtureSuite implements Serializable {
    +  private transient JavaSparkContext sc;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaGaussianMixture");
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void runGaussianMixture() {
+    List<Vector> points = Lists.newArrayList(
    +      Vectors.dense(1.0, 2.0, 6.0),
    +      Vectors.dense(1.0, 3.0, 0.0),
    +      Vectors.dense(1.0, 4.0, 6.0)
    +    );
    +
+    JavaRDD<Vector> data = sc.parallelize(points, 2);
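+    // Fit a 2-component Gaussian mixture with a fixed seed; the model should expose exactly 2 Gaussians.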
    +    GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
    +      .run(data);
    +    assertEquals(model.gaussians().length, 2);
+    JavaRDD<Integer> predictions = model.predict(data);
    +    predictions.first();
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
    index 96c2da169961f..b48f190f599a2 100644
    --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
    +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
    @@ -28,12 +28,13 @@
     import org.junit.Before;
     import org.junit.Test;
     
    +import org.apache.spark.api.java.function.Function;
     import org.apache.spark.api.java.JavaPairRDD;
     import org.apache.spark.api.java.JavaRDD;
     import org.apache.spark.api.java.JavaSparkContext;
     import org.apache.spark.mllib.linalg.Matrix;
     import org.apache.spark.mllib.linalg.Vector;
    -
    +import org.apache.spark.mllib.linalg.Vectors;
     
     public class JavaLDASuite implements Serializable {
       private transient JavaSparkContext sc;
    @@ -107,6 +108,18 @@ public void distributedLDAModel() {
         // Check: log probabilities
         assert(model.logLikelihood() < 0.0);
         assert(model.logPrior() < 0.0);
    +
    +    // Check: topic distributions
+    JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
+    // SPARK-5562: topicDistributions returns distributions over topics only for the non-empty
+    // documents, so compare against nonEmptyCorpus instead of corpus.
+    JavaPairRDD<Long, Vector> nonEmptyCorpus = corpus.filter(
+      new Function<Tuple2<Long, Vector>, Boolean>() {
+        public Boolean call(Tuple2<Long, Vector> tuple2) {
    +          return Vectors.norm(tuple2._2(), 1.0) != 0.0;
    +        }
    +    });
    +    assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
       }
     
       @Test
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
    new file mode 100644
    index 0000000000000..3b0e879eec77f
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
    @@ -0,0 +1,82 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.clustering;
    +
    +import java.io.Serializable;
    +import java.util.List;
    +
    +import scala.Tuple2;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import static org.apache.spark.streaming.JavaTestUtils.*;
    +
    +import org.apache.spark.SparkConf;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.streaming.Duration;
    +import org.apache.spark.streaming.api.java.JavaDStream;
    +import org.apache.spark.streaming.api.java.JavaPairDStream;
    +import org.apache.spark.streaming.api.java.JavaStreamingContext;
    +
    +public class JavaStreamingKMeansSuite implements Serializable {
    +
    +  protected transient JavaStreamingContext ssc;
    +
    +  @Before
    +  public void setUp() {
    +    SparkConf conf = new SparkConf()
    +      .setMaster("local[2]")
    +      .setAppName("test")
    +      .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
    +    ssc = new JavaStreamingContext(conf, new Duration(1000));
    +    ssc.checkpoint("checkpoint");
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    ssc.stop();
    +    ssc = null;
    +  }
    +
    +  @Test
    +  @SuppressWarnings("unchecked")
    +  public void javaAPI() {
+    List<Vector> trainingBatch = Lists.newArrayList(
    +      Vectors.dense(1.0),
    +      Vectors.dense(0.0));
+    JavaDStream<Vector> training =
    +      attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
+    List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
+      new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
+      new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+    JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
    +      attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
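+    // Train a single-cluster streaming k-means model and predict on the keyed test stream.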
    +    StreamingKMeans skmeans = new StreamingKMeans()
    +      .setK(1)
    +      .setDecayFactor(1.0)
    +      .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0});
    +    skmeans.trainOn(training);
+    JavaPairDStream<Integer, Integer> prediction = skmeans.predictOnValues(test);
    +    attachTestOutputStream(prediction.count());
    +    runStreams(ssc, 2, 2);
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
    new file mode 100644
    index 0000000000000..b3815ae6039c0
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
    @@ -0,0 +1,58 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm;
    +
    +import java.io.Serializable;
    +
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +import com.google.common.collect.Lists;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
    +
    +
    +public class JavaAssociationRulesSuite implements Serializable {
    +  private transient JavaSparkContext sc;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaFPGrowth");
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void runAssociationRules() {
    +
    +    @SuppressWarnings("unchecked")
+    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList(
+      new FreqItemset<String>(new String[] {"a"}, 15L),
+      new FreqItemset<String>(new String[] {"b"}, 35L),
+      new FreqItemset<String>(new String[] {"a", "b"}, 18L)
    +    ));
    +
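+    // Only exercises the Java API; the generated rules are not checked here.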
+    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
    +  }
    +}
    +
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
    index bd0edf2b9ea62..9ce2c52dca8b6 100644
    --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
    @@ -29,7 +29,6 @@
     
     import org.apache.spark.api.java.JavaRDD;
     import org.apache.spark.api.java.JavaSparkContext;
    -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
     
     public class JavaFPGrowthSuite implements Serializable {
       private transient JavaSparkContext sc;
    @@ -62,10 +61,10 @@ public void runFPGrowth() {
           .setNumPartitions(2)
           .run(rdd);
     
-    List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
+    List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
         assertEquals(18, freqItemsets.size());
     
-    for (FreqItemset<String> itemset: freqItemsets) {
+    for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
           // Test return types.
       List<String> items = itemset.javaItems();
           long freq = itemset.freq();
    diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
    new file mode 100644
    index 0000000000000..62f7f26b7c98f
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
    @@ -0,0 +1,56 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.stat;
    +
    +import java.io.Serializable;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import static org.junit.Assert.assertEquals;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +
    +public class JavaStatisticsSuite implements Serializable {
    +  private transient JavaSparkContext sc;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaStatistics");
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void testCorr() {
+    JavaRDD<Double> x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0));
+    JavaRDD<Double> y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3));
    +
    +    Double corr1 = Statistics.corr(x, y);
    +    Double corr2 = Statistics.corr(x, y, "pearson");
    +    // Check default method
    +    assertEquals(corr1, corr2);
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
    index 2b04a3034782e..63d2fa31c7499 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
    @@ -17,15 +17,18 @@
     
     package org.apache.spark.ml
     
    +import scala.collection.JavaConverters._
    +
     import org.mockito.Matchers.{any, eq => meq}
     import org.mockito.Mockito.when
    -import org.scalatest.FunSuite
     import org.scalatest.mock.MockitoSugar.mock
     
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.feature.HashingTF
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.sql.DataFrame
     
    -class PipelineSuite extends FunSuite {
    +class PipelineSuite extends SparkFunSuite {
     
       abstract class MyModel extends Model[MyModel]
     
    @@ -81,4 +84,28 @@ class PipelineSuite extends FunSuite {
           pipeline.fit(dataset)
         }
       }
    +
    +  test("PipelineModel.copy") {
    +    val hashingTF = new HashingTF()
    +      .setNumFeatures(100)
    +    val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
    +    val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
    +    require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
    +      "copy should handle extra stage params")
    +  }
    +
    +  test("pipeline model constructors") {
    +    val transform0 = mock[Transformer]
    +    val model1 = mock[MyModel]
    +
    +    val stages = Array(transform0, model1)
    +    val pipelineModel0 = new PipelineModel("pipeline0", stages)
    +    assert(pipelineModel0.uid === "pipeline0")
    +    assert(pipelineModel0.stages === stages)
    +
    +    val stagesAsList = stages.toList.asJava
    +    val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList)
    +    assert(pipelineModel1.uid === "pipeline1")
    +    assert(pipelineModel1.stages === stages)
    +  }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
    index 17ddd335deb6d..512cffb1acb66 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
    @@ -17,9 +17,9 @@
     
     package org.apache.spark.ml.attribute
     
    -import org.scalatest.FunSuite
    +import org.apache.spark.SparkFunSuite
     
    -class AttributeGroupSuite extends FunSuite {
    +class AttributeGroupSuite extends SparkFunSuite {
     
       test("attribute group") {
         val attrs = Array(
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
    index ec9b717e41ce8..c5fd2f9d5a22a 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
    @@ -17,11 +17,10 @@
     
     package org.apache.spark.ml.attribute
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.sql.types._
     
    -class AttributeSuite extends FunSuite {
    +class AttributeSuite extends SparkFunSuite {
     
       test("default numeric attribute") {
         val attr: NumericAttribute = NumericAttribute.defaultAttr
    @@ -216,5 +215,10 @@ class AttributeSuite extends FunSuite {
         assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
         val fldWithMeta = new StructField("x", DoubleType, false, metadata)
         assert(Attribute.fromStructField(fldWithMeta).isNumeric)
    +    // Attribute.fromStructField should accept any NumericType, not just DoubleType
    +    val longFldWithMeta = new StructField("x", LongType, false, metadata)
    +    assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
    +    val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
    +    assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
    index 03af4ecd7a7e0..73b4805c4c597 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
    @@ -17,19 +17,18 @@
     
     package org.apache.spark.ml.classification
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.impl.TreeTests
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.ml.tree.LeafNode
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
    -  DecisionTreeSuite => OldDecisionTreeSuite}
    +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
    -
    -class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
    +class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import DecisionTreeClassifierSuite.compareAPIs
     
    @@ -56,6 +55,12 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
           OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new DecisionTreeClassifier)
    +    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
    +    ParamsSuite.checkParams(model)
    +  }
    +
       /////////////////////////////////////////////////////////////////////////////
       // Tests calling train()
       /////////////////////////////////////////////////////////////////////////////
    @@ -251,7 +256,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
       */
     }
     
    -private[ml] object DecisionTreeClassifierSuite extends FunSuite {
    +private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
     
       /**
        * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
    @@ -266,9 +271,9 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite {
         val oldTree = OldDecisionTree.train(data, oldStrategy)
         val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
         val newTree = dt.fit(newData)
    -    // Use parent, fittingParamMap from newTree since these are not checked anyways.
    +    // Use parent from newTree since this is not checked anyways.
         val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
    -      oldTree, newTree.parent, categoricalFeatures)
    +      oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
         TreeTests.checkEqual(oldTreeAsNew, newTree)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
    index 16c758b82c7cd..82c345491bb3c 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
    @@ -17,9 +17,11 @@
     
     package org.apache.spark.ml.classification
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.impl.TreeTests
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.ml.regression.DecisionTreeRegressionModel
    +import org.apache.spark.ml.tree.LeafNode
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
    @@ -31,7 +33,7 @@ import org.apache.spark.sql.DataFrame
     /**
      * Test suite for [[GBTClassifier]].
      */
    -class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
    +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import GBTClassifierSuite.compareAPIs
     
    @@ -52,6 +54,14 @@ class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
           sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new GBTClassifier)
    +    val model = new GBTClassificationModel("gbtc",
    +      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
    +      Array(1.0))
    +    ParamsSuite.checkParams(model)
    +  }
    +
       test("Binary classification with continuous features: Log Loss") {
         val categoricalFeatures = Map.empty[Int, Int]
         testCombinations.foreach {
    @@ -128,9 +138,9 @@ private object GBTClassifierSuite {
         val oldModel = oldGBT.run(data)
         val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
         val newModel = gbt.fit(newData)
    -    // Use parent, fittingParamMap from newTree since these are not checked anyways.
    +    // Use parent from newTree since this is not checked anyways.
         val oldModelAsNew = GBTClassificationModel.fromOld(
    -      oldModel, newModel.parent, categoricalFeatures)
    +      oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
         TreeTests.checkEqual(oldModelAsNew, newModel)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
    index 4df8016009171..b7dd44753896a 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
    @@ -17,42 +17,38 @@
     
     package org.apache.spark.ml.classification
     
    -import org.scalatest.FunSuite
    -
    -import org.apache.spark.mllib.classification.LogisticRegressionSuite
    -import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.classification.LogisticRegressionSuite._
    +import org.apache.spark.mllib.linalg.{Vectors, Vector}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
    -import org.apache.spark.sql.{DataFrame, Row, SQLContext}
    -
    +import org.apache.spark.sql.{DataFrame, Row}
     
    -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
    +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  @transient var sqlContext: SQLContext = _
       @transient var dataset: DataFrame = _
       @transient var binaryDataset: DataFrame = _
       private val eps: Double = 1e-5
     
       override def beforeAll(): Unit = {
         super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    -
    -    dataset = sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite
    -      .generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 4))
    -
    -    /**
    -     * Here is the instruction describing how to export the test data into CSV format
    -     * so we can validate the training accuracy compared with R's glmnet package.
    -     *
    -     * import org.apache.spark.mllib.classification.LogisticRegressionSuite
    -     * val nPoints = 10000
    -     * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
    -     * val xMean = Array(5.843, 3.057, 3.758, 1.199)
    -     * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
    -     * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
    -     *   weights, xMean, xVariance, true, nPoints, 42), 1)
    -     * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
    -     *   + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
    +
    +    dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
    +
    +    /*
    +       Here is the instruction describing how to export the test data into CSV format
    +       so we can validate the training accuracy compared with R's glmnet package.
    +
    +       import org.apache.spark.mllib.classification.LogisticRegressionSuite
    +       val nPoints = 10000
    +       val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
    +       val xMean = Array(5.843, 3.057, 3.758, 1.199)
    +       val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
    +       val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput(
    +         weights, xMean, xVariance, true, nPoints, 42), 1)
    +       data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", "
    +         + x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
          */
         binaryDataset = {
           val nPoints = 10000
    @@ -60,32 +56,39 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
           val xMean = Array(5.843, 3.057, 3.758, 1.199)
           val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
     
    -      val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
    -        weights, xMean, xVariance, true, nPoints, 42)
    +      val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
     
    -      sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite
    -        .generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42), 4))
    +      sqlContext.createDataFrame(
    +        generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42))
         }
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new LogisticRegression)
    +    val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0)
    +    ParamsSuite.checkParams(model)
    +  }
    +
       test("logistic regression: default params") {
         val lr = new LogisticRegression
    -    assert(lr.getLabelCol == "label")
    -    assert(lr.getFeaturesCol == "features")
    -    assert(lr.getPredictionCol == "prediction")
    -    assert(lr.getRawPredictionCol == "rawPrediction")
    -    assert(lr.getProbabilityCol == "probability")
    -    assert(lr.getFitIntercept == true)
    +    assert(lr.getLabelCol === "label")
    +    assert(lr.getFeaturesCol === "features")
    +    assert(lr.getPredictionCol === "prediction")
    +    assert(lr.getRawPredictionCol === "rawPrediction")
    +    assert(lr.getProbabilityCol === "probability")
    +    assert(lr.getFitIntercept)
    +    assert(lr.getStandardization)
         val model = lr.fit(dataset)
         model.transform(dataset)
           .select("label", "probability", "prediction", "rawPrediction")
           .collect()
         assert(model.getThreshold === 0.5)
    -    assert(model.getFeaturesCol == "features")
    -    assert(model.getPredictionCol == "prediction")
    -    assert(model.getRawPredictionCol == "rawPrediction")
    -    assert(model.getProbabilityCol == "probability")
    +    assert(model.getFeaturesCol === "features")
    +    assert(model.getPredictionCol === "prediction")
    +    assert(model.getRawPredictionCol === "rawPrediction")
    +    assert(model.getProbabilityCol === "probability")
         assert(model.intercept !== 0.0)
    +    assert(model.hasParent)
       }
     
       test("logistic regression doesn't fit intercept when fitIntercept is off") {
    @@ -103,7 +106,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
           .setThreshold(0.6)
           .setProbabilityCol("myProbability")
         val model = lr.fit(dataset)
    -    val parent = model.parent
    +    val parent = model.parent.asInstanceOf[LogisticRegression]
         assert(parent.getMaxIter === 10)
         assert(parent.getRegParam === 1.0)
         assert(parent.getThreshold === 0.6)
    @@ -129,12 +132,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
         // Call fit() with new params, and check as many params as we can.
         val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
           lr.probabilityCol -> "theProb")
    -    val parent2 = model2.parent
    +    val parent2 = model2.parent.asInstanceOf[LogisticRegression]
         assert(parent2.getMaxIter === 5)
         assert(parent2.getRegParam === 0.1)
         assert(parent2.getThreshold === 0.4)
         assert(model2.getThreshold === 0.4)
    -    assert(model2.getProbabilityCol == "theProb")
    +    assert(model2.getProbabilityCol === "theProb")
       }
     
       test("logistic regression: Predictor, Classifier methods") {
    @@ -206,267 +209,443 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
       }
     
       test("binary logistic regression with intercept without regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                     s0
    -     * (Intercept)  2.8366423
    -     * data.V2     -0.5895848
    -     * data.V3      0.8931147
    -     * data.V4     -0.3925051
    -     * data.V5     -0.7996864
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                           s0
    +       (Intercept)  2.8366423
    +       data.V2     -0.5895848
    +       data.V3      0.8931147
    +       data.V4     -0.3925051
    +       data.V5     -0.7996864
          */
         val interceptR = 2.8366423
    -    val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
    +    val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +    assert(model1.intercept ~== interceptR relTol 1E-3)
    +    assert(model1.weights ~= weightsR relTol 1E-3)
    +
    +    // Without regularization, the solution is the same with or without standardization.
    +    assert(model2.intercept ~== interceptR relTol 1E-3)
    +    assert(model2.weights ~= weightsR relTol 1E-3)
       }
     
       test("binary logistic regression without intercept without regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights =
    -     *     coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                     s0
    -     * (Intercept)   .
    -     * data.V2     -0.3534996
    -     * data.V3      1.2964482
    -     * data.V4     -0.3571741
    -     * data.V5     -0.7407946
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights =
    +           coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                           s0
    +       (Intercept)   .
    +       data.V2     -0.3534996
    +       data.V3      1.2964482
    +       data.V4     -0.3571741
    +       data.V5     -0.7407946
          */
         val interceptR = 0.0
    -    val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
    +    val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946)
    +
    +    assert(model1.intercept ~== interceptR relTol 1E-3)
    +    assert(model1.weights ~= weightsR relTol 1E-2)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +    // Without regularization, the solution is the same with or without standardization.
    +    assert(model2.intercept ~== interceptR relTol 1E-3)
    +    assert(model2.weights ~= weightsR relTol 1E-2)
       }
     
       test("binary logistic regression with intercept with L1 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(1.0).setRegParam(0.12)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                      s0
    -     * (Intercept) -0.05627428
    -     * data.V2       .
    -     * data.V3       .
    -     * data.V4     -0.04325749
    -     * data.V5     -0.02481551
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept) -0.05627428
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.04325749
    +       data.V5     -0.02481551
    +     */
    +    val interceptR1 = -0.05627428
    +    val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-2)
    +    assert(model1.weights ~= weightsR1 absTol 2E-2)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
    +           standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                           s0
    +       (Intercept)  0.3722152
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.1665453
    +       data.V5       .
          */
    -    val interceptR = -0.05627428
    -    val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551)
    -
    -    assert(model.intercept ~== interceptR relTol 1E-2)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
    -    assert(model.weights(3) ~== weightsR(3) relTol 2E-2)
    +    val interceptR2 = 0.3722152
    +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0)
    +
    +    assert(model2.intercept ~== interceptR2 relTol 1E-2)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression without intercept with L1 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -      .setElasticNetParam(1.0).setRegParam(0.12)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
    -     *     intercept=FALSE))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                      s0
    -     * (Intercept)   .
    -     * data.V2       .
    -     * data.V3       .
    -     * data.V4     -0.05189203
    -     * data.V5     -0.03891782
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
    +           intercept=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)   .
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.05189203
    +       data.V5     -0.03891782
          */
    -    val interceptR = 0.0
    -    val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782)
    +    val interceptR1 = 0.0
    +    val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
    +    assert(model1.weights ~= weightsR1 absTol 1E-3)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12,
    +           intercept=FALSE, standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)   .
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.08420782
    +       data.V5       .
    +     */
    +    val interceptR2 = 0.0
    +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-2)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
    +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression with intercept with L2 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(0.0).setRegParam(1.37)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                      s0
    -     * (Intercept)  0.15021751
    -     * data.V2     -0.07251837
    -     * data.V3      0.10724191
    -     * data.V4     -0.04865309
    -     * data.V5     -0.10062872
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)  0.15021751
    +       data.V2     -0.07251837
    +       data.V3      0.10724191
    +       data.V4     -0.04865309
    +       data.V5     -0.10062872
    +     */
    +    val interceptR1 = 0.15021751
    +    val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
    +    assert(model1.weights ~= weightsR1 relTol 1E-3)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
    +           standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)  0.48657516
    +       data.V2     -0.05155371
    +       data.V3      0.02301057
    +       data.V4     -0.11482896
    +       data.V5     -0.06266838
          */
    -    val interceptR = 0.15021751
    -    val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872)
    -
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +    val interceptR2 = 0.48657516
    +    val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838)
    +
    +    assert(model2.intercept ~== interceptR2 relTol 1E-3)
    +    assert(model2.weights ~= weightsR2 relTol 1E-3)
       }
     
       test("binary logistic regression without intercept with L2 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -      .setElasticNetParam(0.0).setRegParam(1.37)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
    -     *     intercept=FALSE))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                      s0
    -     * (Intercept)   .
    -     * data.V2     -0.06099165
    -     * data.V3      0.12857058
    -     * data.V4     -0.04708770
    -     * data.V5     -0.09799775
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
    +           intercept=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)   .
    +       data.V2     -0.06099165
    +       data.V3      0.12857058
    +       data.V4     -0.04708770
    +       data.V5     -0.09799775
          */
    -    val interceptR = 0.0
    -    val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
    +    val interceptR1 = 0.0
    +    val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775)
    +
    +    assert(model1.intercept ~== interceptR1 absTol 1E-3)
    +    assert(model1.weights ~= weightsR1 relTol 1E-2)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37,
    +           intercept=FALSE, standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                             s0
    +       (Intercept)   .
    +       data.V2     -0.005679651
    +       data.V3      0.048967094
    +       data.V4     -0.093714016
    +       data.V5     -0.053314311
    +     */
    +    val interceptR2 = 0.0
    +    val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-2)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-2)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
    +    assert(model2.weights ~= weightsR2 relTol 1E-2)
       }
     
       test("binary logistic regression with intercept with ElasticNet regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(0.38).setRegParam(0.21)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                      s0
    -     * (Intercept)  0.57734851
    -     * data.V2     -0.05310287
    -     * data.V3       .
    -     * data.V4     -0.08849250
    -     * data.V5     -0.15458796
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)  0.57734851
    +       data.V2     -0.05310287
    +       data.V3       .
    +       data.V4     -0.08849250
    +       data.V5     -0.15458796
          */
    -    val interceptR = 0.57734851
    -    val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796)
    -
    -    assert(model.intercept ~== interceptR relTol 6E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 5E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    -    assert(model.weights(2) ~== weightsR(2) relTol 5E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-3)
    +    val interceptR1 = 0.57734851
    +    val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 6E-3)
    +    assert(model1.weights ~== weightsR1 absTol 5E-3)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
    +           standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)  0.51555993
    +       data.V2       .
    +       data.V3       .
    +       data.V4     -0.18807395
    +       data.V5     -0.05350074
    +     */
    +    val interceptR2 = 0.51555993
    +    val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074)
    +
    +    assert(model2.intercept ~== interceptR2 relTol 6E-3)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression without intercept with ElasticNet regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(false)
    -      .setElasticNetParam(0.38).setRegParam(0.21)
    -    val model = trainer.fit(binaryDataset)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
    -     *     intercept=FALSE))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                      s0
    -     * (Intercept)   .
    -     * data.V2     -0.001005743
    -     * data.V3      0.072577857
    -     * data.V4     -0.081203769
    -     * data.V5     -0.142534158
    +    val trainer1 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(false)
    +      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
    +           intercept=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)   .
    +       data.V2     -0.001005743
    +       data.V3      0.072577857
    +       data.V4     -0.081203769
    +       data.V5     -0.142534158
          */
    -    val interceptR = 0.0
    -    val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
    +    val interceptR1 = 0.0
    +    val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158)
    +
    +    assert(model1.intercept ~== interceptR1 relTol 1E-3)
    +    assert(model1.weights ~= weightsR1 absTol 1E-2)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21,
    +           intercept=FALSE, standardize=FALSE))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept)   .
    +       data.V2       .
    +       data.V3      0.03345223
    +       data.V4     -0.11304532
    +       data.V5       .
    +     */
    +    val interceptR2 = 0.0
    +    val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0)
     
    -    assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) absTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) absTol 1E-2)
    -    assert(model.weights(2) ~== weightsR(2) relTol 1E-3)
    -    assert(model.weights(3) ~== weightsR(3) relTol 1E-2)
    +    assert(model2.intercept ~== interceptR2 absTol 1E-3)
    +    assert(model2.weights ~= weightsR2 absTol 1E-3)
       }
     
       test("binary logistic regression with intercept with strong L1 regularization") {
    -    val trainer = (new LogisticRegression).setFitIntercept(true)
    -      .setElasticNetParam(1.0).setRegParam(6.0)
    -    val model = trainer.fit(binaryDataset)
    +    val trainer1 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(true)
    +    val trainer2 = (new LogisticRegression).setFitIntercept(true)
    +      .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false)
    +
    +    val model1 = trainer1.fit(binaryDataset)
    +    val model2 = trainer2.fit(binaryDataset)
     
         val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label }
           .treeAggregate(new MultiClassSummarizer)(
    @@ -478,50 +657,48 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
                 classSummarizer1.merge(classSummarizer2)
             }).histogram
     
    -    /**
    -     * For binary logistic regression with strong L1 regularization, all the weights will be zeros.
    -     * As a result,
    -     * {{{
    -     * P(0) = 1 / (1 + \exp(b)), and
    -     * P(1) = \exp(b) / (1 + \exp(b))
    -     * }}}, hence
    -     * {{{
    -     * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
    -     * }}}
    +    /*
    +       For binary logistic regression with strong L1 regularization, all the weights will be zero.
    +       As a result,
    +           P(0) = 1 / (1 + \exp(b)), and
    +           P(1) = \exp(b) / (1 + \exp(b)),
    +       hence
    +           b = \log{P(1) / P(0)} = \log{count_1 / count_0}
          */
         val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
    -    val weightsTheory = Array(0.0, 0.0, 0.0, 0.0)
    -
    -    assert(model.intercept ~== interceptTheory relTol 1E-5)
    -    assert(model.weights(0) ~== weightsTheory(0) absTol 1E-6)
    -    assert(model.weights(1) ~== weightsTheory(1) absTol 1E-6)
    -    assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6)
    -    assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6)
    -
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * > library("glmnet")
    -     * > data <- read.csv("path", header=FALSE)
    -     * > label = factor(data$V1)
    -     * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    -     * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
    -     * > weights
    -     * 5 x 1 sparse Matrix of class "dgCMatrix"
    -     *                      s0
    -     * (Intercept) -0.2480643
    -     * data.V2      0.0000000
    -     * data.V3       .
    -     * data.V4       .
    -     * data.V5       .
    +    val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0)
    +
    +    assert(model1.intercept ~== interceptTheory relTol 1E-5)
    +    assert(model1.weights ~= weightsTheory absTol 1E-6)
    +
    +    assert(model2.intercept ~== interceptTheory relTol 1E-5)
    +    assert(model2.weights ~= weightsTheory absTol 1E-6)
    +
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE)
    +       label = factor(data$V1)
    +       features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    +       weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
    +       weights
    +
    +       5 x 1 sparse Matrix of class "dgCMatrix"
    +                            s0
    +       (Intercept) -0.2480643
    +       data.V2      0.0000000
    +       data.V3       .
    +       data.V4       .
    +       data.V5       .
          */
         val interceptR = -0.248065
    -    val weightsR = Array(0.0, 0.0, 0.0, 0.0)
    +    val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
     
    -    assert(model.intercept ~== interceptR relTol 1E-5)
    -    assert(model.weights(0) ~== weightsR(0) absTol 1E-6)
    -    assert(model.weights(1) ~== weightsR(1) absTol 1E-6)
    -    assert(model.weights(2) ~== weightsR(2) absTol 1E-6)
    -    assert(model.weights(3) ~== weightsR(3) absTol 1E-6)
    +    assert(model1.intercept ~== interceptR relTol 1E-5)
    +    assert(model1.weights ~= weightsR absTol 1E-6)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
    index e65ffae918ca9..75cf5bd4ead4f 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
    @@ -17,28 +17,29 @@
     
     package org.apache.spark.ml.classification
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.attribute.NominalAttribute
    +import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
     import org.apache.spark.ml.util.MetadataUtils
    -import org.apache.spark.mllib.classification.LogisticRegressionSuite._
     import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    +import org.apache.spark.mllib.classification.LogisticRegressionSuite._
     import org.apache.spark.mllib.evaluation.MulticlassMetrics
    +import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.{DataFrame, SQLContext}
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.types.Metadata
     
    -class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
    +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  @transient var sqlContext: SQLContext = _
       @transient var dataset: DataFrame = _
       @transient var rdd: RDD[LabeledPoint] = _
     
       override def beforeAll(): Unit = {
         super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    +
         val nPoints = 1000
     
         // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
    @@ -54,10 +55,17 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
         dataset = sqlContext.createDataFrame(rdd)
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new OneVsRest)
    +    val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0)
    +    val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel))
    +    ParamsSuite.checkParams(model)
    +  }
    +
       test("one-vs-rest: default params") {
         val numClasses = 3
         val ova = new OneVsRest()
    -    ova.setClassifier(new LogisticRegression)
    +      .setClassifier(new LogisticRegression)
         assert(ova.getLabelCol === "label")
         assert(ova.getPredictionCol === "prediction")
         val ovaModel = ova.fit(dataset)
    @@ -95,9 +103,40 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
         val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
         ova.fit(datasetWithLabelMetadata)
       }
    +
    +  test("SPARK-8049: OneVsRest shouldn't output temp columns") {
    +    val logReg = new LogisticRegression()
    +      .setMaxIter(1)
    +    val ovr = new OneVsRest()
    +      .setClassifier(logReg)
    +    val output = ovr.fit(dataset).transform(dataset)
    +    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
    +  }
    +
    +  test("OneVsRest.copy and OneVsRestModel.copy") {
    +    val lr = new LogisticRegression()
    +      .setMaxIter(1)
    +
    +    val ovr = new OneVsRest()
    +    withClue("copy with classifier unset should work") {
    +      ovr.copy(ParamMap(lr.maxIter -> 10))
    +    }
    +    ovr.setClassifier(lr)
    +    val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10))
    +    require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects")
    +    require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
    +      "copy should handle extra classifier params")
    +
    +    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
    +    ovrModel.models.foreach { case m: LogisticRegressionModel =>
    +      require(m.getThreshold === 0.1, "copy should handle extra model params")
    +    }
    +  }
     }
     
    -private class MockLogisticRegression extends LogisticRegression {
    +private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
    +
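    +  // Auxiliary constructor so tests can instantiate the mock without supplying an explicit uid.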
    +  def this() = this("mockLogReg")
     
       setMaxIter(1)
     
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
    index c41def9330504..1b6b69c7dc71e 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
    @@ -17,9 +17,10 @@
     
     package org.apache.spark.ml.classification
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.impl.TreeTests
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.ml.tree.LeafNode
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
    @@ -28,11 +29,10 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
    -
     /**
      * Test suite for [[RandomForestClassifier]].
      */
    -class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
    +class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import RandomForestClassifierSuite.compareAPIs
     
    @@ -63,6 +63,13 @@ class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
         compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new RandomForestClassifier)
    +    val model = new RandomForestClassificationModel("rfc",
    +      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
    +    ParamsSuite.checkParams(model)
    +  }
    +
       test("Binary classification with continuous features:" +
         " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
         val rf = new RandomForestClassifier()
    @@ -158,9 +165,11 @@ private object RandomForestClassifierSuite {
           data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
         val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
         val newModel = rf.fit(newData)
    -    // Use parent, fittingParamMap from newTree since these are not checked anyways.
    +    // Use parent from newTree since this is not checked anyway.
         val oldModelAsNew = RandomForestClassificationModel.fromOld(
    -      oldModel, newModel.parent, categoricalFeatures)
    +      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
         TreeTests.checkEqual(oldModelAsNew, newModel)
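    +    // The fitted ensemble has a parent estimator, but its individual trees do not.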
    +    assert(newModel.hasParent)
    +    assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
       }
     }
    diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
    similarity index 65%
    rename from core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
    rename to mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
    index 8df4f3b554c41..def869fe66777 100644
    --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
    @@ -15,17 +15,14 @@
      * limitations under the License.
      */
     
    -package org.apache.spark.scheduler.cluster.mesos
    +package org.apache.spark.ml.evaluation
     
    -import org.apache.spark.SparkContext
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
     
    -private[spark] object MemoryUtils {
    -  // These defaults copied from YARN
    -  val OVERHEAD_FRACTION = 0.10
    -  val OVERHEAD_MINIMUM = 384
    +class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
     
    -  def calculateTotalMemory(sc: SparkContext): Int = {
    -    sc.conf.getInt("spark.mesos.executor.memoryOverhead",
    -      math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory
    +  test("params") {
    +    ParamsSuite.checkParams(new BinaryClassificationEvaluator)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
    new file mode 100644
    index 0000000000000..5b203784559e2
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
    @@ -0,0 +1,76 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.evaluation
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.ml.regression.LinearRegression
    +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
    +import org.apache.spark.mllib.util.TestingUtils._
    +
    +class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new RegressionEvaluator)
    +  }
    +
    +  test("Regression Evaluator: default params") {
    +    /**
    +     * Here are the instructions for exporting the test data into CSV format
    +     * so that we can validate the metrics against R's mmetric package.
    +     *
    +     * import org.apache.spark.mllib.util.LinearDataGenerator
    +     * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3,
    +     *   Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1))
    +     * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
    +     *   .saveAsTextFile("path")
    +     */
    +    val dataset = sqlContext.createDataFrame(
    +      sc.parallelize(LinearDataGenerator.generateLinearInput(
    +        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
    +
    +    /**
    +     * Using the following R code to load the data, train the model and evaluate metrics.
    +     *
    +     * > library("glmnet")
    +     * > library("rminer")
    +     * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
    +     * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
    +     * > label <- as.numeric(data$V1)
    +     * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)
    +     * > rmse <- mmetric(label, predict(model, features), metric='RMSE')
    +     * > mae <- mmetric(label, predict(model, features), metric='MAE')
    +     * > r2 <- mmetric(label, predict(model, features), metric='R2')
    +     */
    +    val trainer = new LinearRegression
    +    val model = trainer.fit(dataset)
    +    val predictions = model.transform(dataset)
    +
    +    // default = rmse
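    +    // (RMSE and MAE are negated so that a larger evaluator output always means a better model.)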
    +    val evaluator = new RegressionEvaluator()
    +    assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001)
    +
    +    // r2 score
    +    evaluator.setMetricName("r2")
    +    assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001)
    +
    +    // mae
    +    evaluator.setMetricName("mae")
    +    assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001)
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
    index caf1b759593f3..2086043983661 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
    @@ -17,24 +17,24 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.mllib.util.TestingUtils._
    -import org.apache.spark.sql.{DataFrame, Row, SQLContext}
    -
    +import org.apache.spark.sql.{DataFrame, Row}
     
    -class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
    +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       @transient var data: Array[Double] = _
    -  @transient var sqlContext: SQLContext = _
     
       override def beforeAll(): Unit = {
         super.beforeAll()
    -    sqlContext = new SQLContext(sc)
         data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new Binarizer)
    +  }
    +
       test("Binarize continuous features with default parameter") {
         val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
         val dataFrame: DataFrame = sqlContext.createDataFrame(
    @@ -52,7 +52,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
     
       test("Binarize continuous features with setter") {
         val threshold: Double = 0.2
    -    val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) 
    +    val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
         val dataFrame: DataFrame = sqlContext.createDataFrame(
             data.zip(thresholdBinarized)).toDF("feature", "expected")
     
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
    index 1900820400aee..ec85e0d151e07 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
    @@ -19,21 +19,17 @@ package org.apache.spark.ml.feature
     
     import scala.util.Random
     
    -import org.scalatest.FunSuite
    -
    -import org.apache.spark.SparkException
    +import org.apache.spark.{SparkException, SparkFunSuite}
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
    -import org.apache.spark.sql.{DataFrame, Row, SQLContext}
    -
    -class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
    +import org.apache.spark.sql.{DataFrame, Row}
     
    -  @transient private var sqlContext: SQLContext = _
    +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  override def beforeAll(): Unit = {
    -    super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    +  test("params") {
    +    ParamsSuite.checkParams(new Bucketizer)
       }
     
       test("Bucket continuous features, without -inf,inf") {
    @@ -117,12 +113,13 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -private object BucketizerSuite extends FunSuite {
    +private object BucketizerSuite extends SparkFunSuite {
       /** Brute force search for buckets.  Bucket i is defined by the range [split(i), split(i+1)). */
       def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
         require(feature >= splits.head)
         var i = 0
    -    while (i < splits.length - 1) {
    +    val n = splits.length - 1
    +    while (i < n) {
           if (feature < splits(i + 1)) return i
           i += 1
         }
    @@ -138,7 +135,8 @@ private object BucketizerSuite extends FunSuite {
               s" ${splits.mkString(", ")}")
         }
         var i = 0
    -    while (i < splits.length - 1) {
    +    val n = splits.length - 1
    +    while (i < n) {
           // Split i should fall in bucket i.
           testFeature(splits(i), i)
           // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf.
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
    new file mode 100644
    index 0000000000000..e90d9d4ef21ff
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
    @@ -0,0 +1,73 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.mllib.util.TestingUtils._
    +
    +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
    +  }
    +
    +  test("CountVectorizerModel common cases") {
    +    val df = sqlContext.createDataFrame(Seq(
    +      (0, "a b c d".split(" ").toSeq,
    +        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
    +      (1, "a b b c d  a".split(" ").toSeq,
    +        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
    +      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))),
    +      (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string
    +      (4, "a notInDict d".split(" ").toSeq,
    +        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))  // with words not in vocabulary
    +    )).toDF("id", "words", "expected")
    +    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
    +      .setInputCol("words")
    +      .setOutputCol("features")
    +    val output = cv.transform(df).collect()
    +    output.foreach { p =>
    +      val features = p.getAs[Vector]("features")
    +      val expected = p.getAs[Vector]("expected")
    +      assert(features ~== expected absTol 1e-14)
    +    }
    +  }
    +
    +  test("CountVectorizerModel with minTermFreq") {
    +    val df = sqlContext.createDataFrame(Seq(
    +      (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
    +      (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))),
    +      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())),
    +      (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq())))
    +    ).toDF("id", "words", "expected")
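    +    // minTermFreq = 3 drops per-document counts below 3; "e" is not in the vocabulary at all.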
    +    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
    +      .setInputCol("words")
    +      .setOutputCol("features")
    +      .setMinTermFreq(3)
    +    val output = cv.transform(df).collect()
    +    output.foreach { p =>
    +      val features = p.getAs[Vector]("features")
    +      val expected = p.getAs[Vector]("expected")
    +      assert(features ~== expected absTol 1e-14)
    +    }
    +  }
    +}
    +
    +
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
    new file mode 100644
    index 0000000000000..37ed2367c33f7
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
    @@ -0,0 +1,73 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import scala.beans.BeanInfo
    +
    +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.{DataFrame, Row}
    +
    +@BeanInfo
    +case class DCTTestData(vec: Vector, wantedVec: Vector)
    +
    +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("forward transform of discrete cosine matches jTransforms result") {
    +    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
    +    val inverse = false
    +
    +    testDCT(data, inverse)
    +  }
    +
    +  test("inverse transform of discrete cosine matches jTransforms result") {
    +    val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
    +    val inverse = true
    +
    +    testDCT(data, inverse)
    +  }
    +
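    +  // Computes the expected result directly with jTransforms' DoubleDCT_1D and compares it with
    +  // the output of the DCT transformer.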
    +  private def testDCT(data: Vector, inverse: Boolean): Unit = {
    +    val expectedResultBuffer = data.toArray.clone()
    +    if (inverse) {
    +      (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true)
    +    } else {
    +      (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true)
    +    }
    +    val expectedResult = Vectors.dense(expectedResultBuffer)
    +
    +    val dataset = sqlContext.createDataFrame(Seq(
    +      DCTTestData(data, expectedResult)
    +    ))
    +
    +    val transformer = new DCT()
    +      .setInputCol("vec")
    +      .setOutputCol("resultVec")
    +      .setInverse(inverse)
    +
    +    transformer.transform(dataset)
    +      .select("resultVec", "wantedVec")
    +      .collect()
    +      .foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
    +        assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
    +      }
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
    new file mode 100644
    index 0000000000000..4157b84b29d01
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
    @@ -0,0 +1,53 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.attribute.AttributeGroup
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.mllib.util.TestingUtils._
    +import org.apache.spark.util.Utils
    +
    +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new HashingTF)
    +  }
    +
    +  test("hashingTF") {
    +    val df = sqlContext.createDataFrame(Seq(
    +      (0, "a a b b c d".split(" ").toSeq)
    +    )).toDF("id", "words")
    +    val n = 100
    +    val hashingTF = new HashingTF()
    +      .setInputCol("words")
    +      .setOutputCol("features")
    +      .setNumFeatures(n)
    +    val output = hashingTF.transform(df)
    +    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
    +    require(attrGroup.numAttributes === Some(n))
    +    val features = output.select("features").first().getAs[Vector](0)
    +    // Assume perfect hash on "a", "b", "c", and "d".
    +    def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
    +    val expected = Vectors.sparse(n,
    +      Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
    +    assert(features ~== expected absTol 1e-14)
    +  }
    +}
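
The `idx` helper in the test above is the whole trick behind HashingTF: a term's hash code is folded into `[0, numFeatures)` with a non-negative modulus and term counts are accumulated at the resulting index. A quick illustrative sketch of that mapping (not part of the patch, and `termIndex` is not a Spark API):

```scala
// Illustrative sketch of the index computation mirrored by `idx` above:
// hash the term with ## and reduce it to a non-negative value mod numFeatures.
def termIndex(term: Any, numFeatures: Int): Int = {
  val raw = term.## % numFeatures
  if (raw < 0) raw + numFeatures else raw
}

// For "a a b b c d" with numFeatures = 100, counts land at:
//   termIndex("a", 100) -> 2.0, termIndex("b", 100) -> 2.0,
//   termIndex("c", 100) -> 1.0, termIndex("d", 100) -> 1.0
// which is exactly the sparse vector the test compares against.
```
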
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
    index eaee3443c1f23..08f80af03429b 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
    @@ -17,21 +17,15 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
    -import org.apache.spark.sql.{Row, SQLContext}
    -
    -class IDFSuite extends FunSuite with MLlibTestSparkContext {
    -
    -  @transient var sqlContext: SQLContext = _
    +import org.apache.spark.sql.Row
     
    -  override def beforeAll(): Unit = {
    -    super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    -  }
    +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
         dataSet.map {
    @@ -46,6 +40,12 @@ class IDFSuite extends FunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new IDF)
    +    val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
    +    ParamsSuite.checkParams(model)
    +  }
    +
       test("compute IDF with default parameter") {
         val numOfFeatures = 4
         val data = Array(
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
    new file mode 100644
    index 0000000000000..c452054bec92f
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
    @@ -0,0 +1,68 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.{Row, SQLContext}
    +
    +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("MinMaxScaler fit basic case") {
    +    val sqlContext = new SQLContext(sc)
    +
    +    val data = Array(
    +      Vectors.dense(1, 0, Long.MinValue),
    +      Vectors.dense(2, 0, 0),
    +      Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)),
    +      Vectors.sparse(3, Array(0), Array(1.5)))
    +
    +    val expected: Array[Vector] = Array(
    +      Vectors.dense(-5, 0, -5),
    +      Vectors.dense(0, 0, 0),
    +      Vectors.sparse(3, Array(0, 2), Array(5, 5)),
    +      Vectors.sparse(3, Array(0), Array(-2.5)))
    +
    +    val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
    +    val scaler = new MinMaxScaler()
    +      .setInputCol("features")
    +      .setOutputCol("scaled")
    +      .setMin(-5)
    +      .setMax(5)
    +
    +    val model = scaler.fit(df)
    +    model.transform(df).select("expected", "scaled").collect()
    +      .foreach { case Row(vector1: Vector, vector2: Vector) =>
+        assert(vector1.equals(vector2), "Transformed vector is different from expected.")
+      }
    +  }
    +
    +  test("MinMaxScaler arguments max must be larger than min") {
    +    withClue("arguments max must be larger than min") {
    +      intercept[IllegalArgumentException] {
    +        val scaler = new MinMaxScaler().setMin(10).setMax(0)
    +        scaler.validateParams()
    +      }
    +      intercept[IllegalArgumentException] {
    +        val scaler = new MinMaxScaler().setMin(0).setMax(0)
    +        scaler.validateParams()
    +      }
    +    }
    +  }
    +}
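
The `expected` vectors in the MinMaxScaler test above follow from the usual per-column min-max rescaling rule with the configured output range [-5, 5]; constant columns map to the midpoint of the range. A short sketch of that rule (illustrative only, written out here just to show where the expected numbers come from):

```scala
// Illustrative per-feature rescaling rule assumed by the expected values above;
// colMin/colMax are the observed column extremes, min/max the configured range.
def rescale(x: Double, colMin: Double, colMax: Double, min: Double, max: Double): Double = {
  if (colMax == colMin) 0.5 * (max + min)                  // constant column -> midpoint
  else (x - colMin) / (colMax - colMin) * (max - min) + min
}

// Column 0 of the test data spans [1.0, 3.0], so with min = -5 and max = 5:
//   rescale(1.0, 1, 3, -5, 5) == -5.0, rescale(2.0, 1, 3, -5, 5) == 0.0,
//   rescale(3.0, 1, 3, -5, 5) == 5.0,  rescale(1.5, 1, 3, -5, 5) == -2.5
// matching the first entries of the expected vectors.
```
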
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
    new file mode 100644
    index 0000000000000..ab97e3dbc6ee0
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
    @@ -0,0 +1,94 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import scala.beans.BeanInfo
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.{DataFrame, Row}
    +
    +@BeanInfo
    +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
    +
    +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
    +  import org.apache.spark.ml.feature.NGramSuite._
    +
    +  test("default behavior yields bigram features") {
    +    val nGram = new NGram()
    +      .setInputCol("inputTokens")
    +      .setOutputCol("nGrams")
    +    val dataset = sqlContext.createDataFrame(Seq(
    +      NGramTestData(
    +        Array("Test", "for", "ngram", "."),
    +        Array("Test for", "for ngram", "ngram .")
+      )))
    +    testNGram(nGram, dataset)
    +  }
    +
    +  test("NGramLength=4 yields length 4 n-grams") {
    +    val nGram = new NGram()
    +      .setInputCol("inputTokens")
    +      .setOutputCol("nGrams")
    +      .setN(4)
    +    val dataset = sqlContext.createDataFrame(Seq(
    +      NGramTestData(
    +        Array("a", "b", "c", "d", "e"),
    +        Array("a b c d", "b c d e")
    +      )))
    +    testNGram(nGram, dataset)
    +  }
    +
    +  test("empty input yields empty output") {
    +    val nGram = new NGram()
    +      .setInputCol("inputTokens")
    +      .setOutputCol("nGrams")
    +      .setN(4)
    +    val dataset = sqlContext.createDataFrame(Seq(
    +      NGramTestData(
    +        Array(),
    +        Array()
    +      )))
    +    testNGram(nGram, dataset)
    +  }
    +
    +  test("input array < n yields empty output") {
    +    val nGram = new NGram()
    +      .setInputCol("inputTokens")
    +      .setOutputCol("nGrams")
    +      .setN(6)
    +    val dataset = sqlContext.createDataFrame(Seq(
    +      NGramTestData(
    +        Array("a", "b", "c", "d", "e"),
    +        Array()
    +      )))
    +    testNGram(nGram, dataset)
    +  }
    +}
    +
    +object NGramSuite extends SparkFunSuite {
    +
    +  def testNGram(t: NGram, dataset: DataFrame): Unit = {
    +    t.transform(dataset)
    +      .select("nGrams", "wantedNGrams")
    +      .collect()
    +      .foreach { case Row(actualNGrams, wantedNGrams) =>
    +        assert(actualNGrams === wantedNGrams)
    +      }
    +  }
    +}
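
The four tests above pin down the n-gram rule itself: consecutive windows of `n` tokens joined by single spaces, with inputs shorter than `n` yielding no output. A one-line sketch of that behaviour (not the transformer's actual implementation, just the contract the tests check):

```scala
// Illustrative n-gram rule matching the expectations above.
def nGrams(tokens: Seq[String], n: Int): Seq[String] =
  tokens.iterator.sliding(n).withPartial(false).map(_.mkString(" ")).toSeq

// nGrams(Seq("Test", "for", "ngram", "."), 2) == Seq("Test for", "for ngram", "ngram .")
// nGrams(Seq("a", "b", "c", "d", "e"), 4)    == Seq("a b c d", "b c d e")
// nGrams(Seq("a", "b", "c", "d", "e"), 6)    == Seq()
```
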
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
    index 9d09f24709e23..9f03470b7f328 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.sql.{DataFrame, Row, SQLContext}
     
     
    -class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
    +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       @transient var data: Array[Vector] = _
       @transient var dataFrame: DataFrame = _
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
    index 92ec407b98d69..65846a846b7b4 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
    @@ -17,20 +17,15 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.sql.{DataFrame, SQLContext}
    -
    -
    -class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
    -  private var sqlContext: SQLContext = _
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions.col
     
    -  override def beforeAll(): Unit = {
    -    super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    -  }
    +class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       def stringIndexed(): DataFrame = {
         val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    @@ -42,15 +37,20 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
         indexer.transform(df)
       }
     
    -  test("OneHotEncoder includeFirst = true") {
    +  test("params") {
    +    ParamsSuite.checkParams(new OneHotEncoder)
    +  }
    +
    +  test("OneHotEncoder dropLast = false") {
         val transformed = stringIndexed()
         val encoder = new OneHotEncoder()
           .setInputCol("labelIndex")
           .setOutputCol("labelVec")
    +      .setDropLast(false)
         val encoded = encoder.transform(transformed)
     
         val output = encoded.select("id", "labelVec").map { r =>
    -      val vec = r.get(1).asInstanceOf[Vector]
    +      val vec = r.getAs[Vector](1)
           (r.getInt(0), vec(0), vec(1), vec(2))
         }.collect().toSet
         // a -> 0, b -> 2, c -> 1
    @@ -59,22 +59,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
         assert(output === expected)
       }
     
    -  test("OneHotEncoder includeFirst = false") {
    +  test("OneHotEncoder dropLast = true") {
         val transformed = stringIndexed()
         val encoder = new OneHotEncoder()
    -      .setIncludeFirst(false)
           .setInputCol("labelIndex")
           .setOutputCol("labelVec")
         val encoded = encoder.transform(transformed)
     
         val output = encoded.select("id", "labelVec").map { r =>
    -      val vec = r.get(1).asInstanceOf[Vector]
    +      val vec = r.getAs[Vector](1)
           (r.getInt(0), vec(0), vec(1))
         }.collect().toSet
         // a -> 0, b -> 2, c -> 1
    -    val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
    -      (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
    +    val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
    +      (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
         assert(output === expected)
       }
     
    +  test("input column with ML attribute") {
    +    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
    +    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
    +      .select(col("size").as("size", attr.toMetadata()))
    +    val encoder = new OneHotEncoder()
    +      .setInputCol("size")
    +      .setOutputCol("encoded")
    +    val output = encoder.transform(df)
    +    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    +    assert(group.size === 2)
    +    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
    +    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
    +  }
    +
    +  test("input column without ML attribute") {
    +    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
    +    val encoder = new OneHotEncoder()
    +      .setInputCol("index")
    +      .setOutputCol("encoded")
    +    val output = encoder.transform(df)
    +    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    +    assert(group.size === 2)
    +    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
    +    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
    +  }
     }
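
The renamed `dropLast` parameter exercised above defaults to true, so with three indexed categories the encoder emits vectors of size 2 and the last category becomes the all-zeros vector; `setDropLast(false)` keeps one slot per category. A hedged end-to-end sketch (not part of the patch; assumes an existing `SparkContext` named `sc`):

```scala
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.SQLContext

// Sketch only: index a string column, then one-hot encode the indices.
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))).toDF("id", "category")

val indexed = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
  .transform(df)

val encoded = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")   // default dropLast = true -> vectors of size 2
  .transform(indexed)            // call setDropLast(false) to keep all 3 slots

encoded.select("id", "categoryVec").show()
```
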
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
    new file mode 100644
    index 0000000000000..d0ae36b28c7a9
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
    @@ -0,0 +1,64 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg.distributed.RowMatrix
    +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.mllib.util.TestingUtils._
    +import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
    +import org.apache.spark.sql.Row
    +
    +class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new PCA)
    +    val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
    +    val model = new PCAModel("pca", new OldPCAModel(2, mat))
    +    ParamsSuite.checkParams(model)
    +  }
    +
    +  test("pca") {
    +    val data = Array(
    +      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
    +      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
    +      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
    +    )
    +
    +    val dataRDD = sc.parallelize(data, 2)
    +
    +    val mat = new RowMatrix(dataRDD)
    +    val pc = mat.computePrincipalComponents(3)
    +    val expected = mat.multiply(pc).rows
    +
    +    val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected")
    +
    +    val pca = new PCA()
    +      .setInputCol("features")
    +      .setOutputCol("pca_features")
    +      .setK(3)
    +      .fit(df)
    +
    +    pca.transform(df).select("pca_features", "expected").collect().foreach {
    +      case Row(x: Vector, y: Vector) =>
+        assert(x ~== y absTol 1e-5, "Transformed vector is different from expected vector.")
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
    index c1d64fba0aa8f..29eebd8960ebc 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
    @@ -17,21 +17,19 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.scalatest.exceptions.TestFailedException
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
    -import org.apache.spark.sql.{Row, SQLContext}
    -import org.scalatest.exceptions.TestFailedException
    -
    -class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext {
    +import org.apache.spark.sql.Row
     
    -  @transient var sqlContext: SQLContext = _
    +class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  override def beforeAll(): Unit = {
    -    super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    +  test("params") {
    +    ParamsSuite.checkParams(new PolynomialExpansion)
       }
     
       test("Polynomial expansion with default parameter") {
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
    index b6939e5870410..99f82bea42688 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
    @@ -17,18 +17,17 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.sql.SQLContext
     
    -class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
    -  private var sqlContext: SQLContext = _
    +class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  override def beforeAll(): Unit = {
    -    super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    +  test("params") {
    +    ParamsSuite.checkParams(new StringIndexer)
    +    val model = new StringIndexerModel("indexer", Array("a", "b"))
    +    ParamsSuite.checkParams(model)
       }
     
       test("StringIndexer") {
    @@ -68,4 +67,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
         val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
         assert(output === expected)
       }
    +
    +  test("StringIndexerModel should keep silent if the input column does not exist.") {
    +    val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
    +      .setInputCol("label")
    +      .setOutputCol("labelIndex")
    +    val df = sqlContext.range(0L, 10L)
    +    assert(indexerModel.transform(df).eq(df))
    +  }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
    index d186ead8f542f..e5fd21c3f6fca 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
    @@ -19,64 +19,66 @@ package org.apache.spark.ml.feature
     
     import scala.beans.BeanInfo
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.sql.{DataFrame, Row, SQLContext}
    +import org.apache.spark.sql.{DataFrame, Row}
     
     @BeanInfo
     case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
     
    -class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
    +class TokenizerSuite extends SparkFunSuite {
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new Tokenizer)
    +  }
    +}
    +
    +class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       import org.apache.spark.ml.feature.RegexTokenizerSuite._
    -  
    -  @transient var sqlContext: SQLContext = _
     
    -  override def beforeAll(): Unit = {
    -    super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    +  test("params") {
    +    ParamsSuite.checkParams(new RegexTokenizer)
       }
     
       test("RegexTokenizer") {
    -    val tokenizer = new RegexTokenizer()
    +    val tokenizer0 = new RegexTokenizer()
    +      .setGaps(false)
    +      .setPattern("\\w+|\\p{Punct}")
           .setInputCol("rawText")
           .setOutputCol("tokens")
    -
         val dataset0 = sqlContext.createDataFrame(Seq(
           TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")),
           TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct"))
         ))
    -    testRegexTokenizer(tokenizer, dataset0)
    +    testRegexTokenizer(tokenizer0, dataset0)
     
         val dataset1 = sqlContext.createDataFrame(Seq(
           TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")),
           TokenizerTestData("Te,st. punct", Array("punct"))
         ))
    +    tokenizer0.setMinTokenLength(3)
    +    testRegexTokenizer(tokenizer0, dataset1)
     
    -    tokenizer.setMinTokenLength(3)
    -    testRegexTokenizer(tokenizer, dataset1)
    -
    -    tokenizer
    -      .setPattern("\\s")
    -      .setGaps(true)
    -      .setMinTokenLength(0)
    +    val tokenizer2 = new RegexTokenizer()
    +      .setInputCol("rawText")
    +      .setOutputCol("tokens")
         val dataset2 = sqlContext.createDataFrame(Seq(
           TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")),
    -      TokenizerTestData("Te,st.  punct", Array("Te,st.", "", "punct"))
    +      TokenizerTestData("Te,st.  punct", Array("Te,st.", "punct"))
         ))
    -    testRegexTokenizer(tokenizer, dataset2)
    +    testRegexTokenizer(tokenizer2, dataset2)
       }
     }
     
    -object RegexTokenizerSuite extends FunSuite {
    +object RegexTokenizerSuite extends SparkFunSuite {
     
       def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
         t.transform(dataset)
           .select("tokens", "wantedTokens")
           .collect()
    -      .foreach {
    -        case Row(tokens, wantedTokens) =>
    -          assert(tokens === wantedTokens)
    -    }
    +      .foreach { case Row(tokens, wantedTokens) =>
    +        assert(tokens === wantedTokens)
    +      }
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
    index 0db27607bc274..bb4d5b983e0d4 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
    @@ -17,20 +17,18 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    -
    -import org.apache.spark.SparkException
    +import org.apache.spark.{SparkException, SparkFunSuite}
    +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.sql.{Row, SQLContext}
    -
    -class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
    +import org.apache.spark.sql.Row
    +import org.apache.spark.sql.functions.col
     
    -  @transient var sqlContext: SQLContext = _
    +class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  override def beforeAll(): Unit = {
    -    super.beforeAll()
    -    sqlContext = new SQLContext(sc)
    +  test("params") {
    +    ParamsSuite.checkParams(new VectorAssembler)
       }
     
       test("assemble") {
    @@ -68,4 +66,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
             assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
         }
       }
    +
    +  test("ML attributes") {
    +    val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
    +    val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
    +    val user = new AttributeGroup("user", Array(
    +      NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
    +      NumericAttribute.defaultAttr.withName("salary")))
    +    val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
    +    val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
    +      .select(
    +        col("browser").as("browser", browser.toMetadata()),
    +        col("hour").as("hour", hour.toMetadata()),
    +        col("count"), // "count" is an integer column without ML attribute
    +        col("user").as("user", user.toMetadata()),
    +        col("ad")) // "ad" is a vector column without ML attribute
    +    val assembler = new VectorAssembler()
    +      .setInputCols(Array("browser", "hour", "count", "user", "ad"))
    +      .setOutputCol("features")
    +    val output = assembler.transform(df)
    +    val schema = output.schema
    +    val features = AttributeGroup.fromStructField(schema("features"))
    +    assert(features.size === 7)
    +    val browserOut = features.getAttr(0)
    +    assert(browserOut === browser.withIndex(0).withName("browser"))
    +    val hourOut = features.getAttr(1)
    +    assert(hourOut === hour.withIndex(1).withName("hour"))
    +    val countOut = features.getAttr(2)
    +    assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2))
    +    val userGenderOut = features.getAttr(3)
    +    assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
    +    val userSalaryOut = features.getAttr(4)
    +    assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
    +    assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
    +    assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
    +  }
     }
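
The "ML attributes" test above encodes the sizing rule for assembled metadata: each scalar input contributes one slot and each vector input contributes one slot per element, so browser(1) + hour(1) + count(1) + user(2) + ad(2) gives the 7 attributes asserted. A small sketch for inspecting the result (illustrative only; assumes the `output` DataFrame built in that test):

```scala
import org.apache.spark.ml.attribute.AttributeGroup

// Sketch: read the assembled metadata back and list one attribute per slot.
val assembled = AttributeGroup.fromStructField(output.schema("features"))
(0 until assembled.size).foreach { i =>
  println(s"slot $i -> ${assembled.getAttr(i)}")
}
```
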
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
    index 38dc83b1241cf..03120c828ca96 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
    @@ -19,22 +19,18 @@ package org.apache.spark.ml.feature
     
     import scala.beans.{BeanInfo, BeanProperty}
     
    -import org.scalatest.FunSuite
    -
    -import org.apache.spark.SparkException
    +import org.apache.spark.{Logging, SparkException, SparkFunSuite}
     import org.apache.spark.ml.attribute._
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.{DataFrame, SQLContext}
    -
    +import org.apache.spark.sql.DataFrame
     
    -class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
    +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     
       import VectorIndexerSuite.FeatureData
     
    -  @transient var sqlContext: SQLContext = _
    -
       // identical, of length 3
       @transient var densePoints1: DataFrame = _
       @transient var sparsePoints1: DataFrame = _
    @@ -86,7 +82,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
         checkPair(densePoints1Seq, sparsePoints1Seq)
         checkPair(densePoints2Seq, sparsePoints2Seq)
     
    -    sqlContext = new SQLContext(sc)
         densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
         sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
         densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
    @@ -97,6 +92,12 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
       private def getIndexer: VectorIndexer =
         new VectorIndexer().setInputCol("features").setOutputCol("indexed")
     
    +  test("params") {
    +    ParamsSuite.checkParams(new VectorIndexer)
    +    val model = new VectorIndexerModel("indexer", 1, Map.empty)
    +    ParamsSuite.checkParams(model)
    +  }
    +
       test("Cannot fit an empty DataFrame") {
         val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
         val vectorIndexer = getIndexer
    @@ -112,11 +113,11 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
         model.transform(sparsePoints1) // should work
         intercept[SparkException] {
           model.transform(densePoints2).collect()
    -      println("Did not throw error when fit, transform were called on vectors of different lengths")
    +      logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
         }
         intercept[SparkException] {
           vectorIndexer.fit(badPoints)
    -      println("Did not throw error when fitting vectors of different lengths in same RDD.")
    +      logInfo("Did not throw error when fitting vectors of different lengths in same RDD.")
         }
       }
     
    @@ -195,7 +196,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
             }
           } catch {
             case e: org.scalatest.exceptions.TestFailedException =>
    -          println(errMsg)
    +          logError(errMsg)
               throw e
           }
         }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
    index 03ba86670d453..aa6ce533fd885 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
    @@ -17,14 +17,21 @@
     
     package org.apache.spark.ml.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.sql.{Row, SQLContext}
    +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
    +
    +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
    +  test("params") {
    +    ParamsSuite.checkParams(new Word2Vec)
    +    val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f))))
    +    ParamsSuite.checkParams(model)
    +  }
     
       test("Word2Vec") {
         val sqlContext = new SQLContext(sc)
    @@ -35,9 +42,9 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
         val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
     
         val codes = Map(
    -      "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451),
    -      "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342),
    -      "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351)
    +      "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
    +      "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
    +      "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
         )
     
         val expected = doc.map { sentence =>
    @@ -52,6 +59,7 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
           .setVectorSize(3)
           .setInputCol("text")
           .setOutputCol("result")
    +      .setSeed(42L)
           .fit(docDF)
     
         model.transform(docDF).select("result", "expected").collect().foreach {
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
    index 1505ad872536b..778abcba22c10 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
    @@ -19,8 +19,7 @@ package org.apache.spark.ml.impl
     
     import scala.collection.JavaConverters._
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.api.java.JavaRDD
     import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
     import org.apache.spark.ml.tree._
    @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.{SQLContext, DataFrame}
     
     
    -private[ml] object TreeTests extends FunSuite {
    +private[ml] object TreeTests extends SparkFunSuite {
     
       /**
        * Convert the given data to a DataFrame, and set the features and label metadata.
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
    index 6056e7d3f6ff8..050d4170ea017 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
    @@ -17,27 +17,28 @@
     
     package org.apache.spark.ml.param
     
    -import org.scalatest.FunSuite
    +import org.apache.spark.SparkFunSuite
     
    -class ParamsSuite extends FunSuite {
    +class ParamsSuite extends SparkFunSuite {
     
       test("param") {
         val solver = new TestParams()
    +    val uid = solver.uid
         import solver.{maxIter, inputCol}
     
         assert(maxIter.name === "maxIter")
    -    assert(maxIter.doc === "max number of iterations (>= 0)")
    -    assert(maxIter.parent.eq(solver))
    -    assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)")
    +    assert(maxIter.doc === "maximum number of iterations (>= 0)")
    +    assert(maxIter.parent === uid)
    +    assert(maxIter.toString === s"${uid}__maxIter")
         assert(!maxIter.isValid(-1))
         assert(maxIter.isValid(0))
         assert(maxIter.isValid(1))
     
         solver.setMaxIter(5)
    -    assert(maxIter.toString ===
    -      "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
    +    assert(solver.explainParam(maxIter) ===
    +      "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")
     
    -    assert(inputCol.toString === "inputCol: input column name (undefined)")
    +    assert(inputCol.toString === s"${uid}__inputCol")
     
         intercept[IllegalArgumentException] {
           solver.setMaxIter(-1)
    @@ -118,7 +119,10 @@ class ParamsSuite extends FunSuite {
         assert(!solver.isDefined(inputCol))
         intercept[NoSuchElementException](solver.getInputCol)
     
    -    assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
    +    assert(solver.explainParam(maxIter) ===
    +      "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
    +    assert(solver.explainParams() ===
    +      Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
     
         assert(solver.getParam("inputCol").eq(inputCol))
         assert(solver.getParam("maxIter").eq(maxIter))
    @@ -131,7 +135,7 @@ class ParamsSuite extends FunSuite {
         intercept[IllegalArgumentException] {
           solver.validateParams()
         }
    -    solver.validateParams(ParamMap(inputCol -> "input"))
    +    solver.copy(ParamMap(inputCol -> "input")).validateParams()
         solver.setInputCol("input")
         assert(solver.isSet(inputCol))
         assert(solver.isDefined(inputCol))
    @@ -148,7 +152,7 @@ class ParamsSuite extends FunSuite {
         assert(!solver.isSet(maxIter))
     
         val copied = solver.copy(ParamMap(solver.maxIter -> 50))
    -    assert(copied.uid !== solver.uid)
    +    assert(copied.uid === solver.uid)
         assert(copied.getInputCol === solver.getInputCol)
         assert(copied.getMaxIter === 50)
       }
    @@ -197,3 +201,31 @@ class ParamsSuite extends FunSuite {
         assert(inArray(1) && inArray(2) && !inArray(0))
       }
     }
    +
    +object ParamsSuite extends SparkFunSuite {
    +
    +  /**
    +   * Checks common requirements for [[Params.params]]:
    +   *   - params are ordered by names
    +   *   - param parent has the same UID as the object's UID
    +   *   - param name is the same as the param method name
    +   *   - obj.copy should return the same type as the obj
    +   */
    +  def checkParams(obj: Params): Unit = {
    +    val clazz = obj.getClass
    +
    +    val params = obj.params
    +    val paramNames = params.map(_.name)
    +    require(paramNames === paramNames.sorted, "params must be ordered by names")
    +    params.foreach { p =>
    +      assert(p.parent === obj.uid)
    +      assert(obj.getParam(p.name) === p)
    +      // TODO: Check that setters return self, which needs special handling for generic types.
    +    }
    +
    +    val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
    +    val copyReturnType = copyMethod.getReturnType
    +    require(copyReturnType === obj.getClass,
    +      s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
    +  }
    +}
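
`checkParams` is the shared helper that the new `test("params")` cases throughout this patch call. A minimal sketch of using it from a concrete suite (the sketch class and the choice of `Binarizer` are illustrative; any Transformer or Estimator with params should work the same way):

```scala
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite

// Sketch: the "params" test pattern used by the suites in this patch.
class BinarizerParamsSketch extends SparkFunSuite {
  test("params") {
    // Checks param ordering by name, parent uid, getParam round-trips,
    // and that copy(ParamMap) returns the concrete type.
    ParamsSuite.checkParams(new Binarizer)
  }
}
```
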
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
    index dc16073640407..2759248344531 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
    @@ -18,9 +18,12 @@
     package org.apache.spark.ml.param
     
     import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
    +import org.apache.spark.ml.util.Identifiable
     
     /** A subclass of Params for testing. */
    -class TestParams extends Params with HasMaxIter with HasInputCol {
    +class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol {
    +
    +  def this() = this(Identifiable.randomUID("testParams"))
     
       def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
     
    @@ -35,7 +38,5 @@ class TestParams extends Params with HasMaxIter with HasInputCol {
         require(isDefined(inputCol))
       }
     
    -  override def copy(extra: ParamMap): TestParams = {
    -    super.copy(extra).asInstanceOf[TestParams]
    -  }
    +  override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
     }
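
The updated `TestParams` shows the pattern this patch rolls out everywhere: the uid becomes a constructor argument, a no-arg constructor generates a random uid, and `copy` delegates to `defaultCopy`. A hedged sketch of a user-defined Params class following the same pattern (class and param names here are illustrative):

```scala
import org.apache.spark.ml.param.{IntParam, ParamMap, Params, ParamValidators}
import org.apache.spark.ml.util.Identifiable

// Sketch only: explicit uid + randomUID fallback + defaultCopy, as in TestParams.
class MyParams(override val uid: String) extends Params {

  def this() = this(Identifiable.randomUID("myParams"))

  // A simple Int param with a default value and a validity check.
  val blockSize: IntParam =
    new IntParam(this, "blockSize", "block size (> 0)", ParamValidators.gt(0))
  setDefault(blockSize -> 1024)

  def getBlockSize: Int = $(blockSize)

  override def copy(extra: ParamMap): MyParams = defaultCopy(extra)
}
```

If I read `defaultCopy` right, it re-creates the instance through the uid constructor and copies the param map over, which lines up with the changed `copied.uid === solver.uid` assertion in ParamsSuite above.
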
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
    new file mode 100644
    index 0000000000000..b3af81a3c60b6
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
    @@ -0,0 +1,36 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.param.shared
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.{ParamMap, Params}
    +
    +class SharedParamsSuite extends SparkFunSuite {
    +
    +  test("outputCol") {
    +
    +    class Obj(override val uid: String) extends Params with HasOutputCol {
    +      override def copy(extra: ParamMap): Obj = defaultCopy(extra)
    +    }
    +
    +    val obj = new Obj("obj")
    +
    +    assert(obj.hasDefault(obj.outputCol))
    +    assert(obj.getOrDefault(obj.outputCol) === "obj__output")
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
    index fc7349330cf86..2e5cfe7027eb6 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
    @@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer
     import scala.language.existentials
     
     import com.github.fommil.netlib.BLAS.{getInstance => blas}
    -import org.scalatest.FunSuite
     
    -import org.apache.spark.{Logging, SparkException}
    +import org.apache.spark.{Logging, SparkException, SparkFunSuite}
     import org.apache.spark.ml.recommendation.ALS._
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    @@ -36,16 +35,14 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.{Row, SQLContext}
     import org.apache.spark.util.Utils
     
    -class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
    +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     
    -  private var sqlContext: SQLContext = _
       private var tempDir: File = _
     
       override def beforeAll(): Unit = {
         super.beforeAll()
         tempDir = Utils.createTempDir()
         sc.setCheckpointDir(tempDir.getAbsolutePath)
    -    sqlContext = new SQLContext(sc)
       }
     
       override def afterAll(): Unit = {
    @@ -345,6 +342,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
           .setImplicitPrefs(implicitPrefs)
           .setNumUserBlocks(numUserBlocks)
           .setNumItemBlocks(numItemBlocks)
    +      .setSeed(0)
         val alpha = als.getAlpha
         val model = als.fit(training.toDF())
         val predictions = model.transform(test.toDF())
    @@ -425,17 +423,18 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
         val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
     
         val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating))
    -    val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4)
    +    val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4, seed = 0)
         assert(longUserFactors.first()._1.getClass === classOf[Long])
     
         val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating))
    -    val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4)
    +    val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4, seed = 0)
         assert(strUserFactors.first()._1.getClass === classOf[String])
       }
     
       test("nonnegative constraint") {
         val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
    -    val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true)
    +    val (userFactors, itemFactors) =
    +      ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true, seed = 0)
         def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = {
           factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _)
         }
    @@ -459,7 +458,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
       test("partitioner in returned factors") {
         val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
         val (userFactors, itemFactors) = ALS.train(
    -      ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4)
    +      ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4, seed = 0)
         for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) {
           assert(userFactors.partitioner.isDefined, s"$tpe factors should have partitioner.")
           val part = userFactors.partitioner.get
    @@ -476,8 +475,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
     
       test("als with large number of iterations") {
         val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
    -    ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2)
    -    ALS.train(
    -      ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true)
    +    ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, seed = 0)
    +    ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2,
    +      implicitPrefs = true, seed = 0)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
    index 5aa81b44ddaf9..33aa9d0d62343 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
    @@ -17,8 +17,7 @@
     
     package org.apache.spark.ml.regression
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.impl.TreeTests
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
    @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
     
     
    -class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
    +class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import DecisionTreeRegressorSuite.compareAPIs
     
    @@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
       // TODO: test("model save/load")   SPARK-6725
     }
     
    -private[ml] object DecisionTreeRegressorSuite extends FunSuite {
    +private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
     
       /**
        * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
    @@ -83,9 +82,9 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite {
         val oldTree = OldDecisionTree.train(data, oldStrategy)
         val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
         val newTree = dt.fit(newData)
    -    // Use parent, fittingParamMap from newTree since these are not checked anyways.
+    // Use parent from newTree since this is not checked anyway.
         val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
    -      oldTree, newTree.parent, categoricalFeatures)
    +      oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
         TreeTests.checkEqual(oldTreeAsNew, newTree)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
    index 25b36ab08b67c..9682edcd9ba84 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
    @@ -17,21 +17,21 @@
     
     package org.apache.spark.ml.regression
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.impl.TreeTests
    +import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.{DataFrame, Row}
     
     
     /**
      * Test suite for [[GBTRegressor]].
      */
    -class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
    +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import GBTRegressorSuite.compareAPIs
     
    @@ -68,6 +68,26 @@ class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("GBTRegressor behaves reasonably on toy data") {
    +    val df = sqlContext.createDataFrame(Seq(
    +      LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
    +      LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
    +      LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
    +      LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)),
    +      LabeledPoint(9, Vectors.dense(1, 2, 6, 4)),
    +      LabeledPoint(-4, Vectors.dense(6, 3, 2, 2))
    +    ))
    +    val gbt = new GBTRegressor()
    +      .setMaxDepth(2)
    +      .setMaxIter(2)
    +    val model = gbt.fit(df)
    +    val preds = model.transform(df)
    +    val predictions = preds.select("prediction").map(_.getDouble(0))
    +    // Checks based on SPARK-8736 (to ensure it is not doing classification)
    +    assert(predictions.max() > 2)
    +    assert(predictions.min() < -1)
    +  }
    +
       // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132
       /*
       test("runWithValidation stops early and performs better on a validation dataset") {
    @@ -129,8 +149,9 @@ private object GBTRegressorSuite {
         val oldModel = oldGBT.run(data)
         val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
         val newModel = gbt.fit(newData)
    -    // Use parent, fittingParamMap from newTree since these are not checked anyways.
    -    val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures)
+    // Use parent from newModel since this is not checked anyway.
    +    val oldModelAsNew = GBTRegressionModel.fromOld(
    +      oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
         TreeTests.checkEqual(oldModelAsNew, newModel)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    index 80323ef5201a6..cf120cf2a4b47 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    @@ -17,61 +17,68 @@
     
     package org.apache.spark.ml.regression
     
    -import org.scalatest.FunSuite
    -
    -import org.apache.spark.mllib.linalg.DenseVector
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
     import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
     import org.apache.spark.mllib.util.TestingUtils._
    -import org.apache.spark.sql.{Row, SQLContext, DataFrame}
    +import org.apache.spark.sql.{DataFrame, Row}
     
    -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
    +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  @transient var sqlContext: SQLContext = _
       @transient var dataset: DataFrame = _
    +  @transient var datasetWithoutIntercept: DataFrame = _
    +
    +  /*
    +     In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
+     is the same as the one trained by R's glmnet package. The following instructions
+     describe how to reproduce the training data in R.
     
    -  /**
    -   * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
    -   * is the same as the one trained by R's glmnet package. The following instruction
    -   * describes how to reproduce the data in R.
    -   *
    -   * import org.apache.spark.mllib.util.LinearDataGenerator
    -   * val data =
    -   *   sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
    -   * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
    +     import org.apache.spark.mllib.util.LinearDataGenerator
    +     val data =
    +       sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
    +         Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
    +     data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
    +       .saveAsTextFile("path")
        */
       override def beforeAll(): Unit = {
         super.beforeAll()
    -    sqlContext = new SQLContext(sc)
         dataset = sqlContext.createDataFrame(
           sc.parallelize(LinearDataGenerator.generateLinearInput(
             6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
    +    /*
+       datasetWithoutIntercept is not needed for correctness testing, but it is useful for
+       illustrating how to train a model without an intercept.
    +     */
    +    datasetWithoutIntercept = sqlContext.createDataFrame(
    +      sc.parallelize(LinearDataGenerator.generateLinearInput(
    +        0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
    +
       }
     
       test("linear regression with intercept without regularization") {
         val trainer = new LinearRegression
         val model = trainer.fit(dataset)
     
    -    /**
    -     * Using the following R code to load the data and train the model using glmnet package.
    -     *
    -     * library("glmnet")
    -     * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
    -     * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
    -     * label <- as.numeric(data$V1)
    -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
    -     * > weights
    -     *  3 x 1 sparse Matrix of class "dgCMatrix"
    -     *                           s0
    -     * (Intercept)         6.300528
    -     * as.numeric.data.V2. 4.701024
    -     * as.numeric.data.V3. 7.198257
    +    /*
    +       Using the following R code to load the data and train the model using glmnet package.
    +
    +       library("glmnet")
    +       data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
    +       features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
    +       label <- as.numeric(data$V1)
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
    +       > weights
    +        3 x 1 sparse Matrix of class "dgCMatrix"
    +                                 s0
    +       (Intercept)         6.300528
    +       as.numeric.data.V2. 4.701024
    +       as.numeric.data.V3. 7.198257
          */
         val interceptR = 6.298698
    -    val weightsR = Array(4.700706, 7.199082)
    +    val weightsR = Vectors.dense(4.700706, 7.199082)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -81,25 +88,87 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("linear regression without intercept without regularization") {
    +    val trainer = (new LinearRegression).setFitIntercept(false)
    +    val model = trainer.fit(dataset)
    +    val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
    +
    +    /*
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
    +         intercept = FALSE))
    +       > weights
    +        3 x 1 sparse Matrix of class "dgCMatrix"
    +                                 s0
    +       (Intercept)         .
    +       as.numeric.data.V2. 6.995908
    +       as.numeric.data.V3. 5.275131
    +     */
    +    val weightsR = Vectors.dense(6.995908, 5.275131)
    +
    +    assert(model.intercept ~== 0 absTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
    +    /*
    +       Then again with the data with no intercept:
    +       > weightsWithoutIntercept
    +       3 x 1 sparse Matrix of class "dgCMatrix"
    +                                   s0
    +       (Intercept)           .
    +       as.numeric.data3.V2. 4.70011
    +       as.numeric.data3.V3. 7.19943
    +     */
    +    val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
    +
    +    assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3)
    +    assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3)
    +  }
    +
       test("linear regression with intercept with L1 regularization") {
         val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
         val model = trainer.fit(dataset)
     
    -    /**
    -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
    -     * > weights
    -     *  3 x 1 sparse Matrix of class "dgCMatrix"
    -     *                           s0
    -     * (Intercept)         6.311546
    -     * as.numeric.data.V2. 2.123522
    -     * as.numeric.data.V3. 4.605651
    +    /*
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
    +       > weights
    +        3 x 1 sparse Matrix of class "dgCMatrix"
    +                                 s0
    +       (Intercept)         6.24300
    +       as.numeric.data.V2. 4.024821
    +       as.numeric.data.V3. 6.679841
          */
    -    val interceptR = 6.243000
    -    val weightsR = Array(4.024821, 6.679841)
    +    val interceptR = 6.24300
    +    val weightsR = Vectors.dense(4.024821, 6.679841)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
    +
    +    model.transform(dataset).select("features", "prediction").collect().foreach {
    +      case Row(features: DenseVector, prediction1: Double) =>
    +        val prediction2 =
    +          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
    +        assert(prediction1 ~== prediction2 relTol 1E-5)
    +    }
    +  }
    +
    +  test("linear regression without intercept with L1 regularization") {
    +    val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
    +      .setFitIntercept(false)
    +    val model = trainer.fit(dataset)
    +
    +    /*
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
    +         intercept=FALSE))
    +       > weights
    +        3 x 1 sparse Matrix of class "dgCMatrix"
    +                                 s0
    +       (Intercept)          .
    +       as.numeric.data.V2. 6.299752
    +       as.numeric.data.V3. 4.772913
    +     */
    +    val interceptR = 0.0
    +    val weightsR = Vectors.dense(6.299752, 4.772913)
    +
    +    assert(model.intercept ~== interceptR absTol 1E-5)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -113,21 +182,49 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
         val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
         val model = trainer.fit(dataset)
     
    -    /**
    -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
    -     * > weights
    -     *  3 x 1 sparse Matrix of class "dgCMatrix"
    -     *                           s0
    -     * (Intercept)         6.328062
    -     * as.numeric.data.V2. 3.222034
    -     * as.numeric.data.V3. 4.926260
    +    /*
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
    +       > weights
    +        3 x 1 sparse Matrix of class "dgCMatrix"
    +                                 s0
    +       (Intercept)         6.328062
    +       as.numeric.data.V2. 3.222034
    +       as.numeric.data.V3. 4.926260
          */
         val interceptR = 5.269376
    -    val weightsR = Array(3.736216, 5.712356)
    +    val weightsR = Vectors.dense(3.736216, 5.712356)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
    +
    +    model.transform(dataset).select("features", "prediction").collect().foreach {
    +      case Row(features: DenseVector, prediction1: Double) =>
    +        val prediction2 =
    +          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
    +        assert(prediction1 ~== prediction2 relTol 1E-5)
    +    }
    +  }
    +
    +  test("linear regression without intercept with L2 regularization") {
    +    val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
    +      .setFitIntercept(false)
    +    val model = trainer.fit(dataset)
    +
    +    /*
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
    +         intercept = FALSE))
    +       > weights
    +        3 x 1 sparse Matrix of class "dgCMatrix"
    +                                 s0
    +       (Intercept)         .
    +       as.numeric.data.V2. 5.522875
    +       as.numeric.data.V3. 4.214502
    +     */
    +    val interceptR = 0.0
    +    val weightsR = Vectors.dense(5.522875, 4.214502)
    +
    +    assert(model.intercept ~== interceptR absTol 1E-3)
    +    assert(model.weights ~== weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -141,21 +238,49 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
         val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
         val model = trainer.fit(dataset)
     
    -    /**
    -     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
    -     * > weights
    -     * 3 x 1 sparse Matrix of class "dgCMatrix"
    -     * s0
    -     * (Intercept)         6.324108
    -     * as.numeric.data.V2. 3.168435
    -     * as.numeric.data.V3. 5.200403
    +    /*
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
    +       > weights
    +       3 x 1 sparse Matrix of class "dgCMatrix"
    +       s0
    +       (Intercept)         6.324108
    +       as.numeric.data.V2. 3.168435
    +       as.numeric.data.V3. 5.200403
          */
         val interceptR = 5.696056
    -    val weightsR = Array(3.670489, 6.001122)
    +    val weightsR = Vectors.dense(3.670489, 6.001122)
     
         assert(model.intercept ~== interceptR relTol 1E-3)
    -    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
    -    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
    +    assert(model.weights ~== weightsR relTol 1E-3)
    +
    +    model.transform(dataset).select("features", "prediction").collect().foreach {
    +      case Row(features: DenseVector, prediction1: Double) =>
    +        val prediction2 =
    +          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
    +        assert(prediction1 ~== prediction2 relTol 1E-5)
    +    }
    +  }
    +
    +  test("linear regression without intercept with ElasticNet regularization") {
    +    val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
    +      .setFitIntercept(false)
    +    val model = trainer.fit(dataset)
    +
    +    /*
    +       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
    +         intercept=FALSE))
    +       > weights
    +       3 x 1 sparse Matrix of class "dgCMatrix"
    +       s0
    +       (Intercept)         .
    +       as.numeric.dataM.V2. 5.673348
    +       as.numeric.dataM.V3. 4.322251
    +     */
    +    val interceptR = 0.0
    +    val weightsR = Vectors.dense(5.673348, 4.322251)
    +
    +    assert(model.intercept ~== interceptR absTol 1E-3)
    +    assert(model.weights ~= weightsR relTol 1E-3)
     
         model.transform(dataset).select("features", "prediction").collect().foreach {
           case Row(features: DenseVector, prediction1: Double) =>
    @@ -164,4 +289,63 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
             assert(prediction1 ~== prediction2 relTol 1E-5)
         }
       }
    +
    +  test("linear regression model training summary") {
    +    val trainer = new LinearRegression
    +    val model = trainer.fit(dataset)
    +
    +    // Training results for the model should be available
    +    assert(model.hasSummary)
    +
    +    // Residuals in [[LinearRegressionResults]] should equal those manually computed
    +    val expectedResiduals = dataset.select("features", "label")
    +      .map { case Row(features: DenseVector, label: Double) =>
    +      val prediction =
    +        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
    +      prediction - label
    +    }
    +      .zip(model.summary.residuals.map(_.getDouble(0)))
    +      .collect()
    +      .foreach { case (manualResidual: Double, resultResidual: Double) =>
    +      assert(manualResidual ~== resultResidual relTol 1E-5)
    +    }
    +
    +    /*
    +       Use the following R code to generate model training results.
    +
    +       predictions <- predict(fit, newx=features)
    +       residuals <- predictions - label
    +       > mean(residuals^2) # MSE
    +       [1] 0.009720325
    +       > mean(abs(residuals)) # MAD
    +       [1] 0.07863206
    +       > cor(predictions, label)^2# r^2
    +               [,1]
    +       s0 0.9998749
    +     */
    +    assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5)
    +    assert(model.summary.meanAbsoluteError ~== 0.07863206  relTol 1E-5)
    +    assert(model.summary.r2 ~== 0.9998749 relTol 1E-5)
    +
    +    // Objective function should be monotonically decreasing for linear regression
    +    assert(
    +      model.summary
    +        .objectiveHistory
    +        .sliding(2)
    +        .forall(x => x(0) >= x(1)))
    +  }
    +
    +  test("linear regression model testset evaluation summary") {
    +    val trainer = new LinearRegression
    +    val model = trainer.fit(dataset)
    +
+    // Evaluating on the training dataset should yield a summary equal to the training summary
    +    val testSummary = model.evaluate(dataset)
    +    assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
    +    assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
    +    model.summary.residuals.select("residuals").collect()
    +      .zip(testSummary.residuals.select("residuals").collect())
    +      .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
    +  }
    +
     }
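The glmnet comments above encode the reference values this suite checks against: glmnet's alpha/lambda correspond to elasticNetParam/regParam on the Spark ML side, and the synthetic data is written out once so R can fit on the identical points. As a minimal standalone sketch of that mapping (the object name and local-mode setup are mine, not part of the patch), the L1 case with lambda = 0.57 could be reproduced like this:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.mllib.util.LinearDataGenerator
    import org.apache.spark.sql.SQLContext

    object GlmnetMappingSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("glmnet-mapping"))
        val sqlContext = new SQLContext(sc)

        // Same synthetic data as the suite: intercept 6.3, weights (4.7, 7.2), noise 0.1.
        val dataset = sqlContext.createDataFrame(
          sc.parallelize(LinearDataGenerator.generateLinearInput(
            6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))

        // glmnet(alpha = 1.0, lambda = 0.57) maps to setElasticNetParam(1.0).setRegParam(0.57).
        val model = new LinearRegression()
          .setElasticNetParam(1.0)
          .setRegParam(0.57)
          .fit(dataset)
        println(s"intercept = ${model.intercept}, weights = ${model.weights}")

        sc.stop()
      }
    }

The fitted values should land near the interceptR/weightsR constants asserted above, within the suite's relative tolerance.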
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
    index 45f09f4fdab81..b24ecaa57c89b 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
    @@ -17,8 +17,7 @@
     
     package org.apache.spark.ml.regression
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.impl.TreeTests
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
    @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
     /**
      * Test suite for [[RandomForestRegressor]].
      */
    -class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
    +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import RandomForestRegressorSuite.compareAPIs
     
    @@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
       */
     }
     
    -private object RandomForestRegressorSuite extends FunSuite {
    +private object RandomForestRegressorSuite extends SparkFunSuite {
     
       /**
        * Train 2 models on the given dataset, one using the old API and one using the new API.
    @@ -114,9 +113,9 @@ private object RandomForestRegressorSuite extends FunSuite {
           data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
         val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
         val newModel = rf.fit(newData)
    -    // Use parent, fittingParamMap from newTree since these are not checked anyways.
+    // Use parent from newModel since it is not checked anyway.
         val oldModelAsNew = RandomForestRegressionModel.fromOld(
    -      oldModel, newModel.parent, categoricalFeatures)
    +      oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
         TreeTests.checkEqual(oldModelAsNew, newModel)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
    index 05313d440fbf6..db64511a76055 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
    @@ -17,15 +17,19 @@
     
     package org.apache.spark.ml.tuning
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.{Estimator, Model}
     import org.apache.spark.ml.classification.LogisticRegression
    -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
    +import org.apache.spark.ml.param.ParamMap
    +import org.apache.spark.ml.param.shared.HasInputCol
    +import org.apache.spark.ml.regression.LinearRegression
     import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
    -import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.sql.{SQLContext, DataFrame}
    +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
    +import org.apache.spark.sql.{DataFrame, SQLContext}
    +import org.apache.spark.sql.types.StructType
     
    -class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
    +class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       @transient var dataset: DataFrame = _
     
    @@ -52,5 +56,90 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
         val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
         assert(parent.getRegParam === 0.001)
         assert(parent.getMaxIter === 10)
    +    assert(cvModel.avgMetrics.length === lrParamMaps.length)
    +  }
    +
    +  test("cross validation with linear regression") {
    +    val dataset = sqlContext.createDataFrame(
    +      sc.parallelize(LinearDataGenerator.generateLinearInput(
    +        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
    +
    +    val trainer = new LinearRegression
    +    val lrParamMaps = new ParamGridBuilder()
    +      .addGrid(trainer.regParam, Array(1000.0, 0.001))
    +      .addGrid(trainer.maxIter, Array(0, 10))
    +      .build()
    +    val eval = new RegressionEvaluator()
    +    val cv = new CrossValidator()
    +      .setEstimator(trainer)
    +      .setEstimatorParamMaps(lrParamMaps)
    +      .setEvaluator(eval)
    +      .setNumFolds(3)
    +    val cvModel = cv.fit(dataset)
    +    val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
    +    assert(parent.getRegParam === 0.001)
    +    assert(parent.getMaxIter === 10)
    +    assert(cvModel.avgMetrics.length === lrParamMaps.length)
    +
    +    eval.setMetricName("r2")
    +    val cvModel2 = cv.fit(dataset)
    +    val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
    +    assert(parent2.getRegParam === 0.001)
    +    assert(parent2.getMaxIter === 10)
    +    assert(cvModel2.avgMetrics.length === lrParamMaps.length)
    +  }
    +
    +  test("validateParams should check estimatorParamMaps") {
    +    import CrossValidatorSuite._
    +
    +    val est = new MyEstimator("est")
    +    val eval = new MyEvaluator
    +    val paramMaps = new ParamGridBuilder()
    +      .addGrid(est.inputCol, Array("input1", "input2"))
    +      .build()
    +
    +    val cv = new CrossValidator()
    +      .setEstimator(est)
    +      .setEstimatorParamMaps(paramMaps)
    +      .setEvaluator(eval)
    +
    +    cv.validateParams() // This should pass.
    +
    +    val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
    +    cv.setEstimatorParamMaps(invalidParamMaps)
    +    intercept[IllegalArgumentException] {
    +      cv.validateParams()
    +    }
    +  }
    +}
    +
    +object CrossValidatorSuite {
    +
    +  abstract class MyModel extends Model[MyModel]
    +
    +  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
    +
    +    override def validateParams(): Unit = require($(inputCol).nonEmpty)
    +
    +    override def fit(dataset: DataFrame): MyModel = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override def transformSchema(schema: StructType): StructType = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
    +  }
    +
    +  class MyEvaluator extends Evaluator {
    +
    +    override def evaluate(dataset: DataFrame): Double = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override val uid: String = "eval"
    +
    +    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
       }
     }
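For context on the linear-regression cross-validation test added above: ParamGridBuilder expands every combination of the supplied values into one ParamMap per candidate, and CrossValidator fits and scores each candidate on every fold before selecting a best model by the evaluator's metric. A minimal sketch of just the grid expansion (no SparkContext is needed for this part; the object name is mine):

    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.ParamGridBuilder

    object ParamGridSketch {
      def main(args: Array[String]): Unit = {
        val lr = new LinearRegression
        // 2 regParam values x 2 maxIter values => 4 ParamMaps, which is why the test expects
        // cvModel.avgMetrics.length === lrParamMaps.length.
        val grid = new ParamGridBuilder()
          .addGrid(lr.regParam, Array(1000.0, 0.001))
          .addGrid(lr.maxIter, Array(0, 10))
          .build()
        grid.foreach(println)
        println(s"number of candidates: ${grid.length}")  // 4
      }
    }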
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
    index 20aa100112bfe..810b70049ec15 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
    @@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning
     
     import scala.collection.mutable
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.param.{ParamMap, TestParams}
     
    -class ParamGridBuilderSuite extends FunSuite {
    +class ParamGridBuilderSuite extends SparkFunSuite {
     
       val solver = new TestParams()
       import solver.{inputCol, maxIter}
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
    index a629dba8a426f..59944416d96a6 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
    @@ -17,13 +17,12 @@
     
     package org.apache.spark.mllib.api.python
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.recommendation.Rating
     
    -class PythonMLLibAPISuite extends FunSuite {
    +class PythonMLLibAPISuite extends SparkFunSuite {
     
       SerDe.initialize()
     
    @@ -84,7 +83,7 @@ class PythonMLLibAPISuite extends FunSuite {
     
         val smt = new SparseMatrix(
           3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
    -      isTransposed=true)
    +      isTransposed = true)
         val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
         assert(smt.toArray === nsmt.toArray)
       }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
    index fb0a194718802..2473510e13514 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
    @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
     import scala.util.Random
     import scala.util.control.Breaks._
     
    -import org.scalatest.FunSuite
     import org.scalatest.Matchers
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.regression._
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
    @@ -101,7 +101,8 @@ object LogisticRegressionSuite {
           // This doesn't work if `vector` is a sparse vector.
           val vectorArray = vector.toArray
           var i = 0
    -      while (i < vectorArray.length) {
    +      val len = vectorArray.length
    +      while (i < len) {
             vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i)
             i += 1
           }
    @@ -118,7 +119,7 @@ object LogisticRegressionSuite {
           }
           // Preventing the overflow when we compute the probability
           val maxMargin = margins.max
    -      if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin
    +      if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin
     
           // Computing the probabilities for each class from the margins.
           val norm = {
    @@ -129,7 +130,7 @@ object LogisticRegressionSuite {
             }
             temp
           }
    -      for (i <-0 until nClasses) probs(i) /= norm
    +      for (i <- 0 until nClasses) probs(i) /= norm
     
           // Compute the cumulative probability so we can generate a random number and assign a label.
           for (i <- 1 until nClasses) probs(i) += probs(i - 1)
    @@ -168,7 +169,7 @@ object LogisticRegressionSuite {
     }
     
     
    -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
    +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
       def validatePrediction(
           predictions: Seq[Double],
           input: Seq[LabeledPoint],
    @@ -195,6 +196,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
           .setStepSize(10.0)
           .setRegParam(0.0)
           .setNumIterations(20)
    +      .setConvergenceTol(0.0005)
     
         val model = lr.run(testRDD)
     
    @@ -540,7 +542,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     
     }
     
    -class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small in both training and prediction using SGD optimizer") {
         val m = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    index ea89b17b7c08f..cffa1ab700f80 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    @@ -19,20 +19,20 @@ package org.apache.spark.mllib.classification
     
     import scala.util.Random
     
    -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
    +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
     import breeze.stats.distributions.{Multinomial => BrzMultinomial}
     
    -import org.scalatest.FunSuite
    -
    -import org.apache.spark.SparkException
    -import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.{SparkException, SparkFunSuite}
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
    +import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
    -
     object NaiveBayesSuite {
     
    +  import NaiveBayes.{Multinomial, Bernoulli}
    +
       private def calcLabel(p: Double, pi: Array[Double]): Int = {
         var sum = 0.0
         for (j <- 0 until pi.length) {
    @@ -48,7 +48,7 @@ object NaiveBayesSuite {
         theta: Array[Array[Double]],  // CXD
         nPoints: Int,
         seed: Int,
    -    modelType: String = "Multinomial",
    +    modelType: String = Multinomial,
         sample: Int = 10): Seq[LabeledPoint] = {
         val D = theta(0).length
         val rnd = new Random(seed)
    @@ -58,10 +58,10 @@ object NaiveBayesSuite {
         for (i <- 0 until nPoints) yield {
           val y = calcLabel(rnd.nextDouble(), _pi)
           val xi = modelType match {
    -        case "Bernoulli" => Array.tabulate[Double] (D) { j =>
    +        case Bernoulli => Array.tabulate[Double] (D) { j =>
                 if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
             }
    -        case "Multinomial" =>
    +        case Multinomial =>
               val mult = BrzMultinomial(BDV(_theta(y)))
               val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
               val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
    @@ -70,7 +70,7 @@ object NaiveBayesSuite {
               counts.toArray.sortBy(_._1).map(_._2)
             case _ =>
               // This should never happen.
    -          throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
    +          throw new UnknownError(s"Invalid modelType: $modelType.")
           }
     
           LabeledPoint(y, Vectors.dense(xi))
    @@ -79,16 +79,16 @@ object NaiveBayesSuite {
     
       /** Bernoulli NaiveBayes with binary labels, 3 features */
       private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
    -    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
    -    "Bernoulli")
    +    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli)
     
       /** Multinomial NaiveBayes with binary labels, 3 features */
       private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
    -    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
    -    "Multinomial")
    +    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
     }
     
    -class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
    +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  import NaiveBayes.{Multinomial, Bernoulli}
     
       def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
         val numOfPredictions = predictions.zip(input).count {
    @@ -117,6 +117,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("model types") {
    +    assert(Multinomial === "multinomial")
    +    assert(Bernoulli === "bernoulli")
    +  }
    +
       test("get, set params") {
         val nb = new NaiveBayes()
         nb.setLambda(2.0)
    @@ -134,16 +139,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
           Array(0.10, 0.10, 0.70, 0.10)  // label 2
         ).map(_.map(math.log))
     
    -    val testData = NaiveBayesSuite.generateNaiveBayesInput(
    -      pi, theta, nPoints, 42, "Multinomial")
    +    val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial)
         val testRDD = sc.parallelize(testData, 2)
         testRDD.cache()
     
    -    val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
    +    val model = NaiveBayes.train(testRDD, 1.0, Multinomial)
         validateModelFit(pi, theta, model)
     
         val validationData = NaiveBayesSuite.generateNaiveBayesInput(
    -      pi, theta, nPoints, 17, "Multinomial")
    +      pi, theta, nPoints, 17, Multinomial)
         val validationRDD = sc.parallelize(validationData, 2)
     
         // Test prediction on RDD.
    @@ -151,6 +155,29 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     
         // Test prediction on Array.
         validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
    +
    +    // Test posteriors
    +    validationData.map(_.features).foreach { features =>
    +      val predicted = model.predictProbabilities(features).toArray
    +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
    +      val expected = expectedMultinomialProbabilities(model, features)
    +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
    +    }
    +  }
    +
    +  /**
    +   * @param model Multinomial Naive Bayes model
    +   * @param testData input to compute posterior probabilities for
    +   * @return posterior class probabilities (in order of labels) for input
    +   */
    +  private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = {
    +    val piVector = new BDV(model.pi)
    +    // model.theta is row-major; treat it as col-major representation of transpose, and transpose:
    +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
    +    val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze)
    +    val classProbs = logClassProbs.toArray.map(math.exp)
    +    val classProbsSum = classProbs.sum
    +    classProbs.map(_ / classProbsSum)
       }
     
       test("Naive Bayes Bernoulli") {
    @@ -159,19 +186,19 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
         val theta = Array(
           Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
           Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
    -      Array(0.02, 0.02, 0.60, 0.02,  0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30)  // label 2
    +      Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30)  // label 2
         ).map(_.map(math.log))
     
         val testData = NaiveBayesSuite.generateNaiveBayesInput(
    -      pi, theta, nPoints, 45, "Bernoulli")
    +      pi, theta, nPoints, 45, Bernoulli)
         val testRDD = sc.parallelize(testData, 2)
         testRDD.cache()
     
    -    val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
    +    val model = NaiveBayes.train(testRDD, 1.0, Bernoulli)
         validateModelFit(pi, theta, model)
     
         val validationData = NaiveBayesSuite.generateNaiveBayesInput(
    -      pi, theta, nPoints, 20, "Bernoulli")
    +      pi, theta, nPoints, 20, Bernoulli)
         val validationRDD = sc.parallelize(validationData, 2)
     
         // Test prediction on RDD.
    @@ -179,6 +206,33 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     
         // Test prediction on Array.
         validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
    +
    +    // Test posteriors
    +    validationData.map(_.features).foreach { features =>
    +      val predicted = model.predictProbabilities(features).toArray
    +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
    +      val expected = expectedBernoulliProbabilities(model, features)
    +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
    +    }
    +  }
    +
    +  /**
    +   * @param model Bernoulli Naive Bayes model
    +   * @param testData input to compute posterior probabilities for
    +   * @return posterior class probabilities (in order of labels) for input
    +   */
    +  private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = {
    +    val piVector = new BDV(model.pi)
    +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
    +    val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length,
    +      model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t
    +    val testBreeze = testData.toBreeze
    +    val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze
    +    val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze)
    +    val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze)
    +    val classProbs = logClassProbs.toArray.map(math.exp)
    +    val classProbsSum = classProbs.sum
    +    classProbs.map(_ / classProbsSum)
       }
     
       test("detect negative values") {
    @@ -208,6 +262,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("detect non zero or one values in Bernoulli") {
    +    val badTrain = Seq(
    +      LabeledPoint(1.0, Vectors.dense(1.0)),
    +      LabeledPoint(0.0, Vectors.dense(2.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0)),
    +      LabeledPoint(1.0, Vectors.dense(0.0)))
    +
    +    intercept[SparkException] {
    +      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli)
    +    }
    +
    +    val okTrain = Seq(
    +      LabeledPoint(1.0, Vectors.dense(1.0)),
    +      LabeledPoint(0.0, Vectors.dense(0.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0)),
    +      LabeledPoint(0.0, Vectors.dense(0.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0))
    +    )
    +
    +    val badPredict = Seq(
    +      Vectors.dense(1.0),
    +      Vectors.dense(2.0),
    +      Vectors.dense(1.0),
    +      Vectors.dense(0.0))
    +
    +    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli)
    +    intercept[SparkException] {
    +      model.predict(sc.makeRDD(badPredict, 2)).collect()
    +    }
    +  }
    +
       test("model save/load: 2.0 to 2.0") {
         val tempDir = Utils.createTempDir()
         val path = tempDir.toURI.toString
    @@ -242,14 +329,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
           assert(model.labels === sameModel.labels)
           assert(model.pi === sameModel.pi)
           assert(model.theta === sameModel.theta)
    -      assert(model.modelType === "Multinomial")
    +      assert(model.modelType === Multinomial)
         } finally {
           Utils.deleteRecursively(tempDir)
         }
       }
     }
     
    -class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small in both training and prediction") {
         val m = 10
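For reference, the new expectedMultinomialProbabilities and expectedBernoulliProbabilities helpers compute the standard Naive Bayes class posteriors (notation mine; in the model, `pi` and `theta` already hold the logarithms):

    \log p(c \mid x) \propto \log\pi_c + \sum_j x_j \log\theta_{c,j}                                              (multinomial)
    \log p(c \mid x) \propto \log\pi_c + \sum_j \bigl[ x_j \log\theta_{c,j} + (1 - x_j)\log(1 - \theta_{c,j}) \bigr]   (Bernoulli)

Both helpers then exponentiate and divide by the sum over classes, which is why the tests can assert that the predicted probabilities sum to 1 and match these quantities to a tight relative tolerance.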
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
    index 6de098b383ba3..b1d78cba9e3dc 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
    @@ -21,9 +21,8 @@ import scala.collection.JavaConversions._
     import scala.util.Random
     
     import org.jblas.DoubleMatrix
    -import org.scalatest.FunSuite
     
    -import org.apache.spark.SparkException
    +import org.apache.spark.{SparkException, SparkFunSuite}
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression._
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
    @@ -46,7 +45,7 @@ object SVMSuite {
         nPoints: Int,
         seed: Int): Seq[LabeledPoint] = {
         val rnd = new Random(seed)
    -    val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
    +    val weightsMat = new DoubleMatrix(1, weights.length, weights : _*)
         val x = Array.fill[Array[Double]](nPoints)(
             Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0))
         val y = x.map { xi =>
    @@ -62,7 +61,7 @@ object SVMSuite {
     
     }
     
    -class SVMSuite extends FunSuite with MLlibTestSparkContext {
    +class SVMSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
         val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
    @@ -91,7 +90,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
         val model = svm.run(testRDD)
     
         val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
    -    val validationRDD  = sc.parallelize(validationData, 2)
    +    val validationRDD = sc.parallelize(validationData, 2)
     
         // Test prediction on RDD.
     
    @@ -117,7 +116,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
         val B = -1.5
         val C = 1.0
     
    -    val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
    +    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
     
         val testRDD = sc.parallelize(testData, 2)
         testRDD.cache()
    @@ -127,8 +126,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
     
         val model = svm.run(testRDD)
     
    -    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
    -    val validationRDD  = sc.parallelize(validationData, 2)
    +    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
    +    val validationRDD = sc.parallelize(validationData, 2)
     
         // Test prediction on RDD.
         validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
    @@ -145,7 +144,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
         val B = -1.5
         val C = 1.0
     
    -    val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
    +    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
     
         val initialB = -1.0
         val initialC = -1.0
    @@ -159,8 +158,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
     
         val model = svm.run(testRDD, initialWeights)
     
    -    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
    -    val validationRDD  = sc.parallelize(validationData,2)
    +    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
    +    val validationRDD = sc.parallelize(validationData, 2)
     
         // Test prediction on RDD.
         validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
    @@ -177,7 +176,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
         val B = -1.5
         val C = 1.0
     
    -    val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
    +    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
         val testRDD = sc.parallelize(testData, 2)
     
         val testRDDInvalid = testRDD.map { lp =>
    @@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small in both training and prediction") {
         val m = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
    index 5683b55e8500a..fd653296c9d97 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
    @@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification
     
     import scala.collection.mutable.ArrayBuffer
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.streaming.dstream.DStream
     import org.apache.spark.streaming.TestSuiteBase
     
    -class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
    +class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
     
       // use longer wait time to ensure job completion
       override def maxWaitTimeMillis: Int = 30000
    @@ -159,4 +158,21 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
         val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
         assert(error.head > 0.8 & error.last < 0.2)
       }
    +
    +  // Test empty RDDs in a stream
    +  test("handling empty RDDs in a stream") {
    +    val model = new StreamingLogisticRegressionWithSGD()
    +      .setInitialWeights(Vectors.dense(-0.1))
    +      .setStepSize(0.01)
    +      .setNumIterations(10)
    +    val numBatches = 10
    +    val emptyInput = Seq.empty[Seq[LabeledPoint]]
    +    val ssc = setupStreams(emptyInput,
    +      (inputDStream: DStream[LabeledPoint]) => {
    +        model.trainOn(inputDStream)
    +        model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
    +      }
    +    )
    +    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
    +  }
     }

    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
    index f356ffa3e3a26..b218d72f1268a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.clustering
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{Vectors, Matrices}
     import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
    -class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
    +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("single cluster") {
         val data = sc.parallelize(Array(
           Vectors.dense(6.0, 9.0),
    @@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
         }
     
       }
    -  
    +
       test("two clusters") {
         val data = sc.parallelize(GaussianTestData.data)
     
    @@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
         val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
         val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
         val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
    -    
    +
         val gmm = new GaussianMixture()
           .setK(2)
           .setInitialModel(initialGmm)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    index 0f2b26d462ad2..0dbbd7127444f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    @@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering
     
     import scala.util.Random
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
    -class KMeansSuite extends FunSuite with MLlibTestSparkContext {
    +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
     
    @@ -75,7 +74,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
         val center = Vectors.dense(1.0, 2.0, 3.0)
     
         // Make sure code runs.
    -    var model = KMeans.train(data, k=2, maxIterations=1)
    +    var model = KMeans.train(data, k = 2, maxIterations = 1)
         assert(model.clusterCenters.size === 2)
       }
     
    @@ -87,7 +86,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
           2)
     
         // Make sure code runs.
    -    var model = KMeans.train(data, k=3, maxIterations=1)
    +    var model = KMeans.train(data, k = 3, maxIterations = 1)
         assert(model.clusterCenters.size === 3)
       }
     
    @@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -object KMeansSuite extends FunSuite {
    +object KMeansSuite extends SparkFunSuite {
       def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
         val singlePoint = isSparse match {
           case true =>
    @@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite {
       }
     }
     
    -class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small in both training and prediction") {
         val m = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    index d5b7d96335744..03a8a2538b464 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    @@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering
     
     import breeze.linalg.{DenseMatrix => BDM}
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class LDASuite extends FunSuite with MLlibTestSparkContext {
    +class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import LDASuite._
     
    @@ -100,9 +99,13 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
     
         // Check: per-doc topic distributions
         val topicDistributions = model.topicDistributions.collect()
    +
         //  Ensure all documents are covered.
    -    assert(topicDistributions.length === tinyCorpus.length)
    -    assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
+    // SPARK-5562: topicDistributions only returns distributions over topics for the non-empty
+    // docs, so compare against nonEmptyTinyCorpus instead of tinyCorpus.
    +    val nonEmptyTinyCorpus = getNonEmptyDoc(tinyCorpus)
    +    assert(topicDistributions.length === nonEmptyTinyCorpus.length)
    +    assert(nonEmptyTinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
         //  Ensure we have proper distributions
         topicDistributions.foreach { case (docId, topicDistribution) =>
           assert(topicDistribution.size === tinyK)
    @@ -233,12 +236,17 @@ private[clustering] object LDASuite {
       }
     
       def tinyCorpus: Array[(Long, Vector)] = Array(
    +    Vectors.dense(0, 0, 0, 0, 0), // empty doc
         Vectors.dense(1, 3, 0, 2, 8),
         Vectors.dense(0, 2, 1, 0, 4),
         Vectors.dense(2, 3, 12, 3, 1),
    +    Vectors.dense(0, 0, 0, 0, 0), // empty doc
         Vectors.dense(0, 3, 1, 9, 8),
         Vectors.dense(1, 1, 4, 2, 6)
       ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
       assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data
     
    +  def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter {
    +    case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0
    +  }
     }
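The SPARK-5562 change referenced above means topicDistributions only covers documents that contain at least one word, so the test filters the corpus the same way via getNonEmptyDoc. A tiny self-contained sketch of that emptiness check (the helper and object names below are mine):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    object EmptyDocSketch {
      // A document is "empty" when every word count is zero, i.e. its L1 norm is 0.
      def isEmptyDoc(wordCounts: Vector): Boolean = Vectors.norm(wordCounts, p = 1.0) == 0.0

      def main(args: Array[String]): Unit = {
        println(isEmptyDoc(Vectors.dense(0, 0, 0, 0, 0)))  // true, like the empty docs added to tinyCorpus
        println(isEmptyDoc(Vectors.dense(1, 3, 0, 2, 8)))  // false
      }
    }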
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    index 6d6fe6fe46bab..19e65f1b53ab5 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    @@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering
     import scala.collection.mutable
     import scala.util.Random
     
    -import org.scalatest.FunSuite
    -
    -import org.apache.spark.SparkContext
    +import org.apache.spark.{SparkContext, SparkFunSuite}
     import org.apache.spark.graphx.{Edge, Graph}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
    -class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
    +class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import org.apache.spark.mllib.clustering.PowerIterationClustering._
     
    @@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
           predictions(a.cluster) += a.id
         }
         assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
    - 
    +
         val model2 = new PowerIterationClustering()
           .setK(2)
           .setInitializationMode("degree")
    @@ -94,11 +92,13 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
          */
         val similarities = Seq[(Long, Long, Double)](
           (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0))
    +    // scalastyle:off
         val expected = Array(
           Array(0.0,     1.0/3.0, 1.0/3.0, 1.0/3.0),
           Array(1.0/2.0,     0.0, 1.0/2.0,     0.0),
           Array(1.0/3.0, 1.0/3.0,     0.0, 1.0/3.0),
           Array(1.0/2.0,     0.0, 1.0/2.0,     0.0))
    +    // scalastyle:on
         val w = normalize(sc.parallelize(similarities, 2))
         w.edges.collect().foreach { case Edge(i, j, x) =>
           assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14)
    @@ -128,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
       }
     }
     
    -object PowerIterationClusteringSuite extends FunSuite {
    +object PowerIterationClusteringSuite extends SparkFunSuite {
       def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
         val assignments = sc.parallelize(
           (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
    index f90025d535e45..ac01622b8a089 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.clustering
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.streaming.TestSuiteBase
     import org.apache.spark.streaming.dstream.DStream
     import org.apache.spark.util.random.XORShiftRandom
     
    -class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
    +class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
     
       override def maxWaitTimeMillis: Int = 30000
     
    @@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
         assert(math.abs(c1) ~== 0.8 absTol 0.6)
       }
     
    +  test("SPARK-7946 setDecayFactor") {
    +    val kMeans = new StreamingKMeans()
    +    assert(kMeans.decayFactor === 1.0)
    +    kMeans.setDecayFactor(2.0)
    +    assert(kMeans.decayFactor === 2.0)
    +  }
    +
       def StreamingKMeansDataGenerator(
           numPoints: Int,
           numBatches: Int,
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
    index 79847633ff0dc..87ccc7eda44ea 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.evaluation
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext {
    +class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("auc computation") {
         val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
         val auc = 4.0
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
    index e0224f960cc43..99d52fabc5309 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.evaluation
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
    +class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
     
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
    index 7dc4f3cfbc4e4..d55bc8c3ec09f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.evaluation
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Matrices
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext {
    +class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("Multiclass evaluation metrics") {
         /*
          * Confusion matrix for 3-class classification with total 9 instances:
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
    index 2537dd62c92f2..f3b19aeb42f84 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.evaluation
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
     
    -class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext {
    +class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("Multilabel evaluation metrics") {
         /*
         * Documents true labels (5x class0, 3x class1, 4x class2):
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
    index 609eed983ff4e..c0924a213a844 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.evaluation
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext {
    +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("Ranking metrics: map, ndcg") {
         val predictionAndLabels = sc.parallelize(
           Seq(
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    index 670b4c34e6095..9de2bdb6d7246 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    @@ -17,16 +17,15 @@
     
     package org.apache.spark.mllib.evaluation
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
    +class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("regression metrics") {
         val predictionAndObservations = sc.parallelize(
    -      Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2)
    +      Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
         val metrics = new RegressionMetrics(predictionAndObservations)
         assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
           "explained variance regression score mismatch")
    @@ -39,7 +38,7 @@ class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
     
       test("regression metrics with complete fitting") {
         val predictionAndObservations = sc.parallelize(
    -      Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2)
    +      Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
         val metrics = new RegressionMetrics(predictionAndObservations)
         assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
           "explained variance regression score mismatch")
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
    index 747f5914598ec..889727fb55823 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
    @@ -17,13 +17,12 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext {
    +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       /*
        *  Contingency tables
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
    index f3a482abda873..ccbf8a91cdd37 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
    @@ -17,13 +17,12 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext {
    +class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("elementwise (hadamard) product should properly apply vector to dense data set") {
         val denseData = Array(
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
    index 0c4dfb7b97c7f..cf279c02334e9 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
    +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("hashing tf on a single doc") {
         val hashingTF = new HashingTF(1000)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
    index 0a5cad7caf8e4..21163633051e5 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
    @@ -17,13 +17,12 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class IDFSuite extends FunSuite with MLlibTestSparkContext {
    +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("idf") {
         val n = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
    index 5c4af2b99e68b..34122d6ed2e95 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
     import breeze.linalg.{norm => brzNorm}
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
    +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       val data = Array(
         Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
    index 758af588f1c69..e57f49191378f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
    @@ -17,13 +17,12 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.linalg.distributed.RowMatrix
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class PCASuite extends FunSuite with MLlibTestSparkContext {
    +class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
     
       private val data = Array(
         Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
    index 7f94564b2a3ae..6ab2fa6770123 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
     import org.apache.spark.rdd.RDD
     
    -class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
    +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       // When the input data is all constant, the variance is zero. The standardization against
       // zero variance is not well-defined, but we decide to just set it into zero here.
    @@ -360,7 +359,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
         }
         withClue("model needs std and mean vectors to be equal size when both are provided") {
           intercept[IllegalArgumentException] {
    -        val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0))
    +        val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0))
           }
         }
       }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
    index 98a98a7599bcb..b6818369208d7 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
    @@ -17,14 +17,13 @@
     
     package org.apache.spark.mllib.feature
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
    -class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
    +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       // TODO: add more tests
     
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
    new file mode 100644
    index 0000000000000..77a2773c36f56
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
    @@ -0,0 +1,89 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("association rules using String type") {
    +    val freqItemsets = sc.parallelize(Seq(
    +      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
    +      (Set("r"), 3L),
    +      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
    +      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
    +      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
    +      (Set("t", "y", "x"), 3L),
    +      (Set("t", "y", "x", "z"), 3L)
    +    ).map {
    +      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
    +    })
    +
    +    val ar = new AssociationRules()
    +
    +    val results1 = ar
    +      .setMinConfidence(0.9)
    +      .run(freqItemsets)
    +      .collect()
    +
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("r z h k p",
    +              "z y x w v u t s",
    +              "s x o n r",
    +              "x z y m t s q e",
    +              "z",
    +              "x z y r q t p"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       ars = apriori(transactions,
    +                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
    +       arsDF = as(ars, "data.frame")
    +       arsDF$support = arsDF$support * length(transactions)
    +       names(arsDF)[names(arsDF) == "support"] = "freq"
    +       > nrow(arsDF)
    +       [1] 23
    +       > sum(arsDF$confidence == 1)
    +       [1] 23
    +     */
    +    assert(results1.size === 23)
    +    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
    +
    +    val results2 = ar
    +      .setMinConfidence(0)
    +      .run(freqItemsets)
    +      .collect()
    +
    +    /* Verify results using the `R` code:
    +       ars = apriori(transactions,
    +                  parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2))
    +       arsDF = as(ars, "data.frame")
    +       arsDF$support = arsDF$support * length(transactions)
    +       names(arsDF)[names(arsDF) == "support"] = "freq"
    +       nrow(arsDF)
    +       sum(arsDF$confidence == 1)
    +       > nrow(arsDF)
    +       [1] 30
    +       > sum(arsDF$confidence == 1)
    +       [1] 23
    +     */
    +    assert(results2.size === 30)
    +    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
    +  }
    +}
    +
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
    index bd5b9cc3afa10..4a9bfdb348d9f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
    @@ -16,11 +16,10 @@
      */
     package org.apache.spark.mllib.fpm
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
    +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
     
     
       test("FP-Growth using String type") {
    @@ -40,6 +39,22 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
           .setMinSupport(0.9)
           .setNumPartitions(1)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("r z h k p",
    +              "z y x w v u t s",
    +              "s x o n r",
    +              "x z y m t s q e",
    +              "z",
    +              "x z y r q t p"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       > eclat(transactions, parameter = list(support = 0.9))
    +       ...
    +       eclat - zero frequent items
    +       set of 0 itemsets
    +     */
         assert(model6.freqItemsets.count() === 0)
     
         val model3 = fpg
    @@ -49,6 +64,33 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
         val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
           (itemset.items.toSet, itemset.freq)
         }
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.5))
    +       fpDF = as(sort(fp), "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > fpDF
    +              items freq
    +       13       {z}    5
    +       14       {x}    4
    +       1      {s,x}    3
    +       2  {t,x,y,z}    3
    +       3    {t,y,z}    3
    +       4    {t,x,y}    3
    +       5    {x,y,z}    3
    +       6      {y,z}    3
    +       7      {x,y}    3
    +       8      {t,y}    3
    +       9    {t,x,z}    3
    +       10     {t,z}    3
    +       11     {t,x}    3
    +       12     {x,z}    3
    +       15       {t}    3
    +       16       {y}    3
    +       17       {s}    3
    +       18       {r}    3
    +     */
         val expected = Set(
           (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
           (Set("r"), 3L),
    @@ -63,15 +105,75 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
           .setMinSupport(0.3)
           .setNumPartitions(4)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.3))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 54
    +     */
         assert(model2.freqItemsets.count() === 54)
     
         val model1 = fpg
           .setMinSupport(0.1)
           .setNumPartitions(8)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.1))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 625
    +     */
         assert(model1.freqItemsets.count() === 625)
       }
     
    +  test("FP-Growth String type association rule generation") {
    +    val transactions = Seq(
    +      "r z h k p",
    +      "z y x w v u t s",
    +      "s x o n r",
    +      "x z y m t s q e",
    +      "z",
    +      "x z y r q t p")
    +      .map(_.split(" "))
    +    val rdd = sc.parallelize(transactions, 2).cache()
    +
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("r z h k p",
    +              "z y x w v u t s",
    +              "s x o n r",
    +              "x z y m t s q e",
    +              "z",
    +              "x z y r q t p"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       ars = apriori(transactions,
    +                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
    +       arsDF = as(ars, "data.frame")
    +       arsDF$support = arsDF$support * length(transactions)
    +       names(arsDF)[names(arsDF) == "support"] = "freq"
    +       > nrow(arsDF)
    +       [1] 23
    +       > sum(arsDF$confidence == 1)
    +       [1] 23
    +     */
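     +    // Mine frequent itemsets at minSupport = 0.5, then generate rules with minConfidence = 0.9.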
    +    val rules = (new FPGrowth())
    +      .setMinSupport(0.5)
    +      .setNumPartitions(2)
    +      .run(rdd)
    +      .generateAssociationRules(0.9)
    +      .collect()
    +
    +    assert(rules.size === 23)
    +    assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
    +  }
    +
       test("FP-Growth using Int type") {
         val transactions = Seq(
           "1 2 3",
    @@ -90,6 +192,23 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
           .setMinSupport(0.9)
           .setNumPartitions(1)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       transactions = as(sapply(
    +         list("1 2 3",
    +              "1 2 3 4",
    +              "5 4 3 2 1",
    +              "6 5 4 3 2 1",
    +              "2 4",
    +              "1 3",
    +              "1 7"),
    +         FUN=function(x) strsplit(x," ",fixed=TRUE)),
    +         "transactions")
    +       > eclat(transactions, parameter = list(support = 0.9))
    +       ...
    +       eclat - zero frequent items
    +       set of 0 itemsets
    +     */
         assert(model6.freqItemsets.count() === 0)
     
         val model3 = fpg
    @@ -101,6 +220,24 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
         val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
           (itemset.items.toSet, itemset.freq)
         }
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.5))
    +       fpDF = as(sort(fp), "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > fpDF
    +          items freq
    +      6     {1}    6
    +      3   {1,3}    5
    +      7     {2}    5
    +      8     {3}    5
    +      1   {2,4}    4
    +      2 {1,2,3}    4
    +      4   {2,3}    4
    +      5   {1,2}    4
    +      9     {4}    4
    +     */
         val expected = Set(
           (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
           (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
    @@ -111,12 +248,30 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
           .setMinSupport(0.3)
           .setNumPartitions(4)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.3))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 15
    +     */
         assert(model2.freqItemsets.count() === 15)
     
         val model1 = fpg
           .setMinSupport(0.1)
           .setNumPartitions(8)
           .run(rdd)
    +
    +    /* Verify results using the `R` code:
    +       fp = eclat(transactions, parameter = list(support = 0.1))
    +       fpDF = as(fp, "data.frame")
    +       fpDF$support = fpDF$support * length(transactions)
    +       names(fpDF)[names(fpDF) == "support"] = "freq"
    +       > nrow(fpDF)
    +       [1] 65
    +     */
         assert(model1.freqItemsets.count() === 65)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
    index 04017f67c311d..a56d7b3579213 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
    @@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm
     
     import scala.language.existentials
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class FPTreeSuite extends FunSuite with MLlibTestSparkContext {
    +class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("add transaction") {
         val tree = new FPTree[String]
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
    new file mode 100644
    index 0000000000000..413436d3db85f
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
    @@ -0,0 +1,120 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.mllib.fpm
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.rdd.RDD
    +
     +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  test("PrefixSpan using Integer type") {
    +
     +    /* Verify results using the `R` code:
    +      library("arulesSequences")
    +      prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
    +      freqItemSeq = cspade(
    +        prefixSpanSeqs,
    +        parameter = list(support =
    +          2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2 ))
    +      resSeq = as(freqItemSeq, "data.frame")
    +      resSeq
    +    */
    +
    +    val sequences = Array(
    +      Array(1, 3, 4, 5),
    +      Array(2, 3, 1),
    +      Array(2, 4, 1),
    +      Array(3, 1, 3, 4, 5),
    +      Array(3, 4, 4, 3),
    +      Array(6, 5, 3))
    +
    +    val rdd = sc.parallelize(sequences, 2).cache()
    +
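     +    // Compares two (pattern, frequency) arrays irrespective of ordering.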
    +    def compareResult(
    +        expectedValue: Array[(Array[Int], Long)],
    +        actualValue: Array[(Array[Int], Long)]): Boolean = {
     +      val sortedExpectedValue = expectedValue.sortWith { (x, y) =>
     +        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
     +      }
     +      val sortedActualValue = actualValue.sortWith { (x, y) =>
     +        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
     +      }
     +      sortedExpectedValue.zip(sortedActualValue)
     +        .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
     +        .reduce(_ && _)
    +    }
    +
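     +    // minSupport = 0.33 keeps patterns supported by at least 2 of the 6 sequences.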
    +    val prefixspan = new PrefixSpan()
    +      .setMinSupport(0.33)
    +      .setMaxPatternLength(50)
    +    val result1 = prefixspan.run(rdd)
    +    val expectedValue1 = Array(
    +      (Array(1), 4L),
    +      (Array(1, 3), 2L),
    +      (Array(1, 3, 4), 2L),
    +      (Array(1, 3, 4, 5), 2L),
    +      (Array(1, 3, 5), 2L),
    +      (Array(1, 4), 2L),
    +      (Array(1, 4, 5), 2L),
    +      (Array(1, 5), 2L),
    +      (Array(2), 2L),
    +      (Array(2, 1), 2L),
    +      (Array(3), 5L),
    +      (Array(3, 1), 2L),
    +      (Array(3, 3), 2L),
    +      (Array(3, 4), 3L),
    +      (Array(3, 4, 5), 2L),
    +      (Array(3, 5), 2L),
    +      (Array(4), 4L),
    +      (Array(4, 5), 2L),
    +      (Array(5), 3L)
    +    )
    +    assert(compareResult(expectedValue1, result1.collect()))
    +
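     +    // Raising minSupport to 0.5 keeps only patterns supported by at least 3 sequences.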
    +    prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
    +    val result2 = prefixspan.run(rdd)
    +    val expectedValue2 = Array(
    +      (Array(1), 4L),
    +      (Array(3), 5L),
    +      (Array(3, 4), 3L),
    +      (Array(4), 4L),
    +      (Array(5), 3L)
    +    )
    +    assert(compareResult(expectedValue2, result2.collect()))
    +
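     +    // Restricting maxPatternLength to 2 drops every pattern longer than two items.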
    +    prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
    +    val result3 = prefixspan.run(rdd)
    +    val expectedValue3 = Array(
    +      (Array(1), 4L),
    +      (Array(1, 3), 2L),
    +      (Array(1, 4), 2L),
    +      (Array(1, 5), 2L),
    +      (Array(2, 1), 2L),
    +      (Array(2), 2L),
    +      (Array(3), 5L),
    +      (Array(3, 1), 2L),
    +      (Array(3, 3), 2L),
    +      (Array(3, 4), 3L),
    +      (Array(3, 5), 2L),
    +      (Array(4), 4L),
    +      (Array(4, 5), 2L),
    +      (Array(5), 3L)
    +    )
    +    assert(compareResult(expectedValue3, result3.collect()))
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
    index 699f009f0f2ec..d34888af2d73b 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
    @@ -17,18 +17,16 @@
     
     package org.apache.spark.mllib.impl
     
    -import org.scalatest.FunSuite
    -
     import org.apache.hadoop.fs.{FileSystem, Path}
     
    -import org.apache.spark.SparkContext
    +import org.apache.spark.{SparkContext, SparkFunSuite}
     import org.apache.spark.graphx.{Edge, Graph}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.storage.StorageLevel
     import org.apache.spark.util.Utils
     
     
    -class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext {
    +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import PeriodicGraphCheckpointerSuite._
     
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    index 002cb253862b5..b0f3f71113c57 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.linalg
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.mllib.linalg.BLAS._
     
    -class BLASSuite extends FunSuite {
    +class BLASSuite extends SparkFunSuite {
     
       test("copy") {
         val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
    @@ -140,7 +139,7 @@ class BLASSuite extends FunSuite {
         syr(alpha, x, dA)
     
         assert(dA ~== expected absTol 1e-15)
    - 
    +
         val dB =
           new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))
     
    @@ -149,7 +148,7 @@ class BLASSuite extends FunSuite {
             syr(alpha, x, dB)
           }
         }
    - 
    +
         val dC =
           new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))
     
    @@ -158,7 +157,7 @@ class BLASSuite extends FunSuite {
             syr(alpha, x, dC)
           }
         }
    - 
    +
         val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))
     
         withClue("Size of vector must match the rank of matrix") {
    @@ -257,32 +256,96 @@ class BLASSuite extends FunSuite {
           new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
         val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
     
    -    val x = new DenseVector(Array(1.0, 2.0, 3.0))
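     +    // dA2/sA2 represent the same matrix as dA/sA but stored with isTransposed = true.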
    +    val dA2 =
    +      new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true)
    +    val sA2 =
    +      new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0),
    +        true)
    +
    +    val dx = new DenseVector(Array(1.0, 2.0, 3.0))
    +    val sx = dx.toSparse
         val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
     
    -    assert(dA.multiply(x) ~== expected absTol 1e-15)
    -    assert(sA.multiply(x) ~== expected absTol 1e-15)
    +    assert(dA.multiply(dx) ~== expected absTol 1e-15)
    +    assert(sA.multiply(dx) ~== expected absTol 1e-15)
    +    assert(dA.multiply(sx) ~== expected absTol 1e-15)
    +    assert(sA.multiply(sx) ~== expected absTol 1e-15)
     
         val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
         val y2 = y1.copy
         val y3 = y1.copy
         val y4 = y1.copy
    +    val y5 = y1.copy
    +    val y6 = y1.copy
    +    val y7 = y1.copy
    +    val y8 = y1.copy
    +    val y9 = y1.copy
    +    val y10 = y1.copy
    +    val y11 = y1.copy
    +    val y12 = y1.copy
    +    val y13 = y1.copy
    +    val y14 = y1.copy
    +    val y15 = y1.copy
    +    val y16 = y1.copy
    +
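     +    // gemv computes y := alpha * A * x + beta * y:
     +    // expected2 = 1.0 * A * x + 2.0 * y, expected3 = 2.0 * A * x + 2.0 * y.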
         val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
         val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
     
    -    gemv(1.0, dA, x, 2.0, y1)
    -    gemv(1.0, sA, x, 2.0, y2)
    -    gemv(2.0, dA, x, 2.0, y3)
    -    gemv(2.0, sA, x, 2.0, y4)
    +    gemv(1.0, dA, dx, 2.0, y1)
    +    gemv(1.0, sA, dx, 2.0, y2)
    +    gemv(1.0, dA, sx, 2.0, y3)
    +    gemv(1.0, sA, sx, 2.0, y4)
    +
    +    gemv(1.0, dA2, dx, 2.0, y5)
    +    gemv(1.0, sA2, dx, 2.0, y6)
    +    gemv(1.0, dA2, sx, 2.0, y7)
    +    gemv(1.0, sA2, sx, 2.0, y8)
    +
    +    gemv(2.0, dA, dx, 2.0, y9)
    +    gemv(2.0, sA, dx, 2.0, y10)
    +    gemv(2.0, dA, sx, 2.0, y11)
    +    gemv(2.0, sA, sx, 2.0, y12)
    +
    +    gemv(2.0, dA2, dx, 2.0, y13)
    +    gemv(2.0, sA2, dx, 2.0, y14)
    +    gemv(2.0, dA2, sx, 2.0, y15)
    +    gemv(2.0, sA2, sx, 2.0, y16)
    +
         assert(y1 ~== expected2 absTol 1e-15)
         assert(y2 ~== expected2 absTol 1e-15)
    -    assert(y3 ~== expected3 absTol 1e-15)
    -    assert(y4 ~== expected3 absTol 1e-15)
    +    assert(y3 ~== expected2 absTol 1e-15)
    +    assert(y4 ~== expected2 absTol 1e-15)
    +
    +    assert(y5 ~== expected2 absTol 1e-15)
    +    assert(y6 ~== expected2 absTol 1e-15)
    +    assert(y7 ~== expected2 absTol 1e-15)
    +    assert(y8 ~== expected2 absTol 1e-15)
    +
    +    assert(y9 ~== expected3 absTol 1e-15)
    +    assert(y10 ~== expected3 absTol 1e-15)
    +    assert(y11 ~== expected3 absTol 1e-15)
    +    assert(y12 ~== expected3 absTol 1e-15)
    +
    +    assert(y13 ~== expected3 absTol 1e-15)
    +    assert(y14 ~== expected3 absTol 1e-15)
    +    assert(y15 ~== expected3 absTol 1e-15)
    +    assert(y16 ~== expected3 absTol 1e-15)
    +
         withClue("columns of A don't match the rows of B") {
           intercept[Exception] {
    -        gemv(1.0, dA.transpose, x, 2.0, y1)
    +        gemv(1.0, dA.transpose, dx, 2.0, y1)
    +      }
    +      intercept[Exception] {
    +        gemv(1.0, sA.transpose, dx, 2.0, y1)
    +      }
    +      intercept[Exception] {
    +        gemv(1.0, dA.transpose, sx, 2.0, y1)
    +      }
    +      intercept[Exception] {
    +        gemv(1.0, sA.transpose, sx, 2.0, y1)
           }
         }
    +
         val dAT =
           new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
         val sAT =
    @@ -291,7 +354,9 @@ class BLASSuite extends FunSuite {
         val dATT = dAT.transpose
         val sATT = sAT.transpose
     
    -    assert(dATT.multiply(x) ~== expected absTol 1e-15)
    -    assert(sATT.multiply(x) ~== expected absTol 1e-15)
    +    assert(dATT.multiply(dx) ~== expected absTol 1e-15)
    +    assert(sATT.multiply(dx) ~== expected absTol 1e-15)
    +    assert(dATT.multiply(sx) ~== expected absTol 1e-15)
    +    assert(sATT.multiply(sx) ~== expected absTol 1e-15)
       }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
    index 2031032373971..dc04258e41d27 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
    @@ -17,11 +17,11 @@
     
     package org.apache.spark.mllib.linalg
     
    -import org.scalatest.FunSuite
    -
     import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM}
     
    -class BreezeMatrixConversionSuite extends FunSuite {
    +import org.apache.spark.SparkFunSuite
    +
    +class BreezeMatrixConversionSuite extends SparkFunSuite {
       test("dense matrix to breeze") {
         val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
         val breeze = mat.toBreeze.asInstanceOf[BDM[Double]]
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
    index 8abdac72902c6..3772c9235ad3a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
    @@ -17,14 +17,14 @@
     
     package org.apache.spark.mllib.linalg
     
    -import org.scalatest.FunSuite
    -
     import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
     
    +import org.apache.spark.SparkFunSuite
    +
     /**
      * Test Breeze vector conversions.
      */
    -class BreezeVectorConversionSuite extends FunSuite {
    +class BreezeVectorConversionSuite extends SparkFunSuite {
     
       val arr = Array(0.1, 0.2, 0.3, 0.4)
       val n = 20
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
    index 86119ec38101e..a270ba2562db9 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
    @@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg
     import java.util.Random
     
     import org.mockito.Mockito.when
    -import org.scalatest.FunSuite
     import org.scalatest.mock.MockitoSugar._
     import scala.collection.mutable.{Map => MutableMap}
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class MatricesSuite extends FunSuite {
    +class MatricesSuite extends SparkFunSuite {
       test("dense matrix construction") {
         val m = 3
         val n = 2
    @@ -455,4 +455,14 @@ class MatricesSuite extends FunSuite {
         lines = mat.toString(5, 100).lines.toArray
         assert(lines.size == 5 && lines.forall(_.size <= 100))
       }
    +
    +  test("numNonzeros and numActives") {
    +    val dm1 = Matrices.dense(3, 2, Array(0, 0, -1, 1, 0, 1))
    +    assert(dm1.numNonzeros === 3)
    +    assert(dm1.numActives === 6)
    +
    +    val sm1 = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
    +    assert(sm1.numNonzeros === 1)
    +    assert(sm1.numActives === 3)
    +  }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    index c3f407cb56e00..03be4119bdaca 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    @@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg
     import scala.util.Random
     
     import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
    -import org.scalatest.FunSuite
     
    -import org.apache.spark.SparkException
    +import org.apache.spark.{Logging, SparkException, SparkFunSuite}
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class VectorsSuite extends FunSuite {
    +class VectorsSuite extends SparkFunSuite with Logging {
     
       val arr = Array(0.1, 0.0, 0.3, 0.4)
       val n = 4
    @@ -182,7 +181,7 @@ class VectorsSuite extends FunSuite {
         malformatted.foreach { s =>
           intercept[SparkException] {
             Vectors.parse(s)
    -        println(s"Didn't detect malformatted string $s.")
    +        logInfo(s"Didn't detect malformatted string $s.")
           }
         }
       }
    @@ -254,13 +253,13 @@ class VectorsSuite extends FunSuite {
     
           val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze)
     
    -      // SparseVector vs. SparseVector 
    -      assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) 
    +      // SparseVector vs. SparseVector
    +      assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
           // DenseVector  vs. SparseVector
           assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
           // DenseVector  vs. DenseVector
           assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8)
    -    }    
    +    }
       }
     
       test("foreachActive") {
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
    index 949d1c9939570..93fe04c139b9a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
    @@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed
     import java.{util => ju}
     
     import breeze.linalg.{DenseMatrix => BDM}
    -import org.scalatest.FunSuite
     
    -import org.apache.spark.SparkException
    +import org.apache.spark.{SparkException, SparkFunSuite}
     import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
    +class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       val m = 5
       val n = 4
    @@ -57,11 +56,13 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
         val random = new ju.Random()
         // This should generate a 4x4 grid of 1x2 blocks.
         val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12)
    +    // scalastyle:off
         val expected0 = Array(
           Array(0, 0, 4, 4,  8,  8, 12),
           Array(1, 1, 5, 5,  9,  9, 13),
           Array(2, 2, 6, 6, 10, 10, 14),
           Array(3, 3, 7, 7, 11, 11, 15))
    +    // scalastyle:on
         for (i <- 0 until 4; j <- 0 until 7) {
           assert(part0.getPartition((i, j)) === expected0(i)(j))
           assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j))
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
    index 04b36a9ef9990..f3728cd036a3f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
    @@ -17,14 +17,13 @@
     
     package org.apache.spark.mllib.linalg.distributed
     
    -import org.scalatest.FunSuite
    -
     import breeze.linalg.{DenseMatrix => BDM}
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.linalg.Vectors
     
    -class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
    +class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       val m = 5
       val n = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
    index 2ab53cc13db71..0ecb7a221a503 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.linalg.distributed
     
    -import org.scalatest.FunSuite
    -
     import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.rdd.RDD
     import org.apache.spark.mllib.linalg.{Matrices, Vectors}
     
    -class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
    +class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       val m = 4
       val n = 3
    @@ -136,6 +135,17 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
         assert(closeToZero(U * brzDiag(s) * V.t - localA))
       }
     
    +  test("validate matrix sizes of svd") {
    +    val k = 2
    +    val A = new IndexedRowMatrix(indexedRows)
    +    val svd = A.computeSVD(k, computeU = true)
    +    assert(svd.U.numRows() === m)
    +    assert(svd.U.numCols() === k)
    +    assert(svd.s.size === k)
    +    assert(svd.V.numRows === n)
    +    assert(svd.V.numCols === k)
    +  }
    +
       test("validate k in svd") {
         val A = new IndexedRowMatrix(indexedRows)
         intercept[IllegalArgumentException] {
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    index 27bb19f472e1e..b6cb53d0c743e 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    @@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed
     import scala.util.Random
     
     import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
    -import org.scalatest.FunSuite
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
     
    -class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
    +class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       val m = 4
       val n = 3
    @@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       var mat: RowMatrix = _
     
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
    index 86481c6e66200..13b754a03943a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
    @@ -20,11 +20,12 @@ package org.apache.spark.mllib.optimization
     import scala.collection.JavaConversions._
     import scala.util.Random
     
    -import org.scalatest.{FunSuite, Matchers}
    +import org.scalatest.Matchers
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression._
    -import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
    +import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext}
     import org.apache.spark.mllib.util.TestingUtils._
     
     object GradientDescentSuite {
    @@ -42,7 +43,7 @@ object GradientDescentSuite {
           offset: Double,
           scale: Double,
           nPoints: Int,
    -      seed: Int): Seq[LabeledPoint]  = {
    +      seed: Int): Seq[LabeledPoint] = {
         val rnd = new Random(seed)
         val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
     
    @@ -61,7 +62,7 @@ object GradientDescentSuite {
       }
     }
     
    -class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers {
    +class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
     
       test("Assert the loss is decreasing.") {
         val nPoints = 10000
    @@ -81,11 +82,11 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc
         // Add a extra variable consisting of all 1.0's for the intercept.
         val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
         val data = testData.map { case LabeledPoint(label, features) =>
    -      label -> Vectors.dense(1.0 +: features.toArray)
    +      label -> MLUtils.appendBias(features)
         }
     
         val dataRDD = sc.parallelize(data, 2).cache()
    -    val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray)
    +    val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
     
         val (_, loss) = GradientDescent.runMiniBatchSGD(
           dataRDD,
    @@ -138,9 +139,48 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc
           "The different between newWeights with/without regularization " +
             "should be initialWeightsWithIntercept.")
       }
    +
    +  test("iteration should end with convergence tolerance") {
    +    val nPoints = 10000
    +    val A = 2.0
    +    val B = -1.5
    +
    +    val initialB = -1.0
    +    val initialWeights = Array(initialB)
    +
    +    val gradient = new LogisticGradient()
    +    val updater = new SimpleUpdater()
    +    val stepSize = 1.0
    +    val numIterations = 10
    +    val regParam = 0
    +    val miniBatchFrac = 1.0
    +    val convergenceTolerance = 5.0e-1
    +
    +    // Add an extra variable consisting of all 1.0's for the intercept.
    +    val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
    +    val data = testData.map { case LabeledPoint(label, features) =>
    +      label -> MLUtils.appendBias(features)
    +    }
    +
    +    val dataRDD = sc.parallelize(data, 2).cache()
    +    val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
    +
    +    val (_, loss) = GradientDescent.runMiniBatchSGD(
    +      dataRDD,
    +      gradient,
    +      updater,
    +      stepSize,
    +      numIterations,
    +      regParam,
    +      miniBatchFrac,
    +      initialWeightsWithIntercept,
    +      convergenceTolerance)
    +
    +    assert(loss.length < numIterations, "convergenceTolerance failed to stop optimization early")
    +  }
     }
     
    -class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small") {
         val m = 4
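
The new test above passes a `convergenceTolerance` argument to `GradientDescent.runMiniBatchSGD` and asserts `loss.length < numIterations`. As a minimal, self-contained sketch of the idea (a hypothetical `ConvergenceSketch` helper, not Spark's implementation), an optimizer can exit before the iteration budget once successive solutions stop changing by more than the tolerance:

```scala
object ConvergenceSketch {
  // Relative change between successive solutions, the usual early-stopping criterion.
  def converged(prev: Double, curr: Double, tol: Double): Boolean = {
    val diff = math.abs(prev - curr)
    if (curr == 0.0) diff < tol else diff / math.abs(curr) < tol
  }

  def main(args: Array[String]): Unit = {
    // Minimize f(w) = (w - 3)^2 with plain gradient descent.
    val stepSize = 0.1
    val maxIterations = 1000
    val tol = 1e-4
    var w = 0.0
    var iterations = 0
    var done = false
    while (!done && iterations < maxIterations) {
      val gradient = 2.0 * (w - 3.0)
      val next = w - stepSize * gradient
      done = converged(w, next, tol) // stops well before maxIterations is reached
      w = next
      iterations += 1
    }
    println(s"stopped after $iterations iterations at w = $w")
  }
}
```

The assertion that `loss.length < numIterations` relies on exactly this kind of early exit.
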
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
    index c8f2adcf155a7..75ae0eb32fb7b 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
    @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization
     
     import scala.util.Random
     
    -import org.scalatest.{FunSuite, Matchers}
    +import org.scalatest.Matchers
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
    +class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
     
       val nPoints = 10000
       val A = 2.0
    @@ -121,7 +122,8 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
           numGDIterations,
           regParam,
           miniBatchFrac,
    -      initialWeightsWithIntercept)
    +      initialWeightsWithIntercept,
    +      convergenceTol)
     
         assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5,
           "The first losses of LBFGS and GD should be the same.")
    @@ -220,7 +222,8 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
           numGDIterations,
           regParam,
           miniBatchFrac,
    -      initialWeightsWithIntercept)
    +      initialWeightsWithIntercept,
    +      convergenceTol)
     
         // for class LBFGS and the optimize method, we only look at the weights
         assert(
    @@ -229,7 +232,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
       }
     }
     
    -class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small") {
         val m = 10
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
    index 22855e4e8f247..d8f9b8c33963d 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
    @@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization
     
     import scala.util.Random
     
    -import org.scalatest.FunSuite
    -
     import org.jblas.{DoubleMatrix, SimpleBlas}
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class NNLSSuite extends FunSuite {
    +class NNLSSuite extends SparkFunSuite {
       /** Generate an NNLS problem whose optimal solution is the all-ones vector. */
       def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = {
         val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*)
    @@ -68,12 +67,14 @@ class NNLSSuite extends FunSuite {
     
       test("NNLS: nonnegativity constraint active") {
         val n = 5
    +    // scalastyle:off
         val ata = new DoubleMatrix(Array(
           Array( 4.377, -3.531, -1.306, -0.139,  3.418),
           Array(-3.531,  4.344,  0.934,  0.305, -2.140),
           Array(-1.306,  0.934,  2.644, -0.203, -0.170),
           Array(-0.139,  0.305, -0.203,  5.883,  1.428),
           Array( 3.418, -2.140, -0.170,  1.428,  4.684)))
    +    // scalastyle:on
         val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636))
     
         val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
    index 0b646cf1ce6c4..4c6e76e47419b 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
    @@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export
     
     import org.dmg.pmml.RegressionModel
     import org.dmg.pmml.RegressionNormalizationMethodType
    -import org.scalatest.FunSuite
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.classification.LogisticRegressionModel
     import org.apache.spark.mllib.classification.SVMModel
     import org.apache.spark.mllib.util.LinearDataGenerator
     
    -class BinaryClassificationPMMLModelExportSuite extends FunSuite {
    +class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite {
     
       test("logistic regression PMML export") {
         val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
    @@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
         // ensure logistic regression has normalization method set to LOGIT
         assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
       }
    -  
    +
       test("linear SVM PMML export") {
         val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
         val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
    -    
    +
         val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
    -    
    +
         // assert that the PMML format is as expected
         assert(svmModelExport.isInstanceOf[PMMLModelExport])
         val pmml = svmModelExport.getPmml
    @@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
         // ensure linear SVM has normalization method set to NONE
         assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
       }
    -  
    +
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
    index f9afbd888dfc5..1d32309481787 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
    @@ -18,12 +18,12 @@
     package org.apache.spark.mllib.pmml.export
     
     import org.dmg.pmml.RegressionModel
    -import org.scalatest.FunSuite
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
     import org.apache.spark.mllib.util.LinearDataGenerator
     
    -class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
    +class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite {
     
       test("linear regression PMML export") {
         val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
    index b985d0446d7b0..b3f9750afa730 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
    @@ -18,12 +18,12 @@
     package org.apache.spark.mllib.pmml.export
     
     import org.dmg.pmml.ClusteringModel
    -import org.scalatest.FunSuite
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.clustering.KMeansModel
     import org.apache.spark.mllib.linalg.Vectors
     
    -class KMeansPMMLModelExportSuite extends FunSuite {
    +class KMeansPMMLModelExportSuite extends SparkFunSuite {
     
       test("KMeansPMMLModelExport generate PMML format") {
         val clusterCenters = Array(
    @@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite {
         val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
         assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
       }
    -  
    +
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
    index f28a4ac8ad01f..af49450961750 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.pmml.export
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
     import org.apache.spark.mllib.clustering.KMeansModel
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
     import org.apache.spark.mllib.util.LinearDataGenerator
     
    -class PMMLModelExportFactorySuite extends FunSuite {
    +class PMMLModelExportFactorySuite extends SparkFunSuite {
     
       test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
         val clusterCenters = Array(
    @@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite {
       test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
         + "when passing a LogisticRegressionModel or SVMModel") {
         val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
    -    
    +
         val logisticRegressionModel =
           new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
         val logisticRegressionModelExport =
           PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
         assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
    -    
    +
         val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
         val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
         assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
       }
    -  
    +
       test("PMMLModelExportFactory throw IllegalArgumentException "
         + "when passing a Multinomial Logistic Regression") {
         /** 3 classes, 2 features */
         val multiclassLogisticRegressionModel = new LogisticRegressionModel(
    -      weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, 
    +      weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
           numFeatures = 2, numClasses = 3)
    -    
    +
         intercept[IllegalArgumentException] {
           PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
         }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
    index b792d819fdabb..a5ca1518f82f5 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
    @@ -19,12 +19,11 @@ package org.apache.spark.mllib.random
     
     import scala.math
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.util.StatCounter
     
     // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
    -class RandomDataGeneratorSuite extends FunSuite {
    +class RandomDataGeneratorSuite extends SparkFunSuite {
     
       def apiChecks(gen: RandomDataGenerator[Double]) {
         // resetting seed should generate the same sequence of random numbers
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
    index 63f2ea916d457..413db2000d6d7 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
    @@ -19,8 +19,7 @@ package org.apache.spark.mllib.random
     
     import scala.collection.mutable.ArrayBuffer
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.SparkContext._
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
    @@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter
      *
      * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
      */
    -class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable {
    +class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable {
     
       def testGeneratedRDD(rdd: RDD[Double],
           expectedSize: Long,
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
    index 57216e8eb4a55..10f5a2be48f7c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.rdd
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
     
    -class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
    +class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("topByKey") {
         val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5,
           1), (3, 5)), 2)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
    index 6d6c0aa5be812..bc64172614830 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.rdd
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.rdd.RDDFunctions._
     
    -class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
    +class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("sliding") {
         val data = 0 until 6
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
    index b3798940ddc38..05b87728d6fdb 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
    @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
     import scala.math.abs
     import scala.util.Random
     
    -import org.scalatest.FunSuite
     import org.jblas.DoubleMatrix
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.storage.StorageLevel
     
    @@ -84,7 +84,7 @@ object ALSSuite {
     }
     
     
    -class ALSSuite extends FunSuite with MLlibTestSparkContext {
    +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("rank-1 matrices") {
         testALS(50, 100, 1, 15, 0.7, 0.3)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
    index 2c92866f3893d..2c8ed057a516a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
    @@ -17,14 +17,13 @@
     
     package org.apache.spark.mllib.recommendation
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.rdd.RDD
     import org.apache.spark.util.Utils
     
    -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
    +class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       val rank = 2
       var userFeatures: RDD[(Int, Array[Double])] = _
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
    index 3b38bdf5ef5eb..ea4f2865757c1 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
    @@ -17,13 +17,14 @@
     
     package org.apache.spark.mllib.regression
     
    -import org.scalatest.{Matchers, FunSuite}
    +import org.scalatest.Matchers
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
    -class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
    +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
     
       private def round(d: Double) = {
         math.round(d * 100).toDouble / 100
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
    index 110c44a7193fd..f8d0af8820e64 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
    @@ -17,11 +17,10 @@
     
     package org.apache.spark.mllib.regression
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     
    -class LabeledPointSuite extends FunSuite {
    +class LabeledPointSuite extends SparkFunSuite {
     
       test("parse labeled points") {
         val points = Seq(
    @@ -32,6 +31,11 @@ class LabeledPointSuite extends FunSuite {
         }
       }
     
    +  test("parse labeled points with whitespaces") {
    +    val point = LabeledPoint.parse("(0.0, [1.0, 2.0])")
    +    assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0)))
    +  }
    +
       test("parse labeled points with v0.9 format") {
         val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0")
         assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0)))
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
    index c9f5dc069ef2e..39537e7bb4c72 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
    @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
     
     import scala.util.Random
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
       MLlibTestSparkContext}
    @@ -32,7 +31,7 @@ private object LassoSuite {
       val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
     }
     
    -class LassoSuite extends FunSuite with MLlibTestSparkContext {
    +class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
         val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
    @@ -67,11 +66,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
         assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
         assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
     
    -    val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
    +    val validationData = LinearDataGenerator
    +      .generateLinearInput(A, Array[Double](B, C), nPoints, 17)
           .map { case LabeledPoint(label, features) =>
           LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
         }
    -    val validationRDD  = sc.parallelize(validationData, 2)
    +    val validationRDD = sc.parallelize(validationData, 2)
     
         // Test prediction on RDD.
         validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
    @@ -100,7 +100,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
         val testRDD = sc.parallelize(testData, 2).cache()
     
         val ls = new LassoWithSGD()
    -    ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
    +    ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005)
     
         val model = ls.run(testRDD, initialWeights)
         val weight0 = model.weights(0)
    @@ -110,11 +110,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
         assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
         assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
     
    -    val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
    +    val validationData = LinearDataGenerator
    +      .generateLinearInput(A, Array[Double](B, C), nPoints, 17)
           .map { case LabeledPoint(label, features) =>
           LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
         }
    -    val validationRDD  = sc.parallelize(validationData,2)
    +    val validationRDD = sc.parallelize(validationData, 2)
     
         // Test prediction on RDD.
         validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
    @@ -141,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small in both training and prediction") {
         val m = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
    index 3781931c2f819..f88a1c33c9f7c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
    @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
     
     import scala.util.Random
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
       MLlibTestSparkContext}
    @@ -32,7 +31,7 @@ private object LinearRegressionSuite {
       val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
     }
     
    -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
    +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
         val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
    @@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small in both training and prediction") {
         val m = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
    index d6c93cc0e49cd..7a781fee634c8 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
    @@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression
     import scala.util.Random
     
     import org.jblas.DoubleMatrix
    -import org.scalatest.FunSuite
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
       MLlibTestSparkContext}
    @@ -33,7 +33,7 @@ private object RidgeRegressionSuite {
       val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
     }
     
    -class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
    +class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
         predictions.zip(input).map { case (prediction, expected) =>
    @@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
    +class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
     
       test("task size should be small in both training and prediction") {
         val m = 4
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
    index 26604dbe6c1ef..a2a4c5f6b8b70 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
    @@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression
     
     import scala.collection.mutable.ArrayBuffer
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.LinearDataGenerator
     import org.apache.spark.streaming.dstream.DStream
     import org.apache.spark.streaming.TestSuiteBase
     
    -class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
    +class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     
       // use longer wait time to ensure job completion
       override def maxWaitTimeMillis: Int = 20000
    @@ -54,6 +53,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
           .setInitialWeights(Vectors.dense(0.0, 0.0))
           .setStepSize(0.2)
           .setNumIterations(25)
    +      .setConvergenceTol(0.0001)
     
         // generate sequence of simulated data
         val numBatches = 10
    @@ -167,4 +167,22 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
         val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
         assert((error.head - error.last) > 2)
       }
    +
    +  // Test empty RDDs in a stream
    +  test("handling empty RDDs in a stream") {
    +    val model = new StreamingLinearRegressionWithSGD()
    +      .setInitialWeights(Vectors.dense(0.0, 0.0))
    +      .setStepSize(0.2)
    +      .setNumIterations(25)
    +    val numBatches = 10
    +    val nPoints = 100
    +    val emptyInput = Seq.empty[Seq[LabeledPoint]]
    +    val ssc = setupStreams(emptyInput,
    +      (inputDStream: DStream[LabeledPoint]) => {
    +        model.trainOn(inputDStream)
    +        model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
    +      }
    +    )
    +    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
    +  }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
    index d20a09b4b4925..c3eeda012571c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
    @@ -17,16 +17,15 @@
     
     package org.apache.spark.mllib.stat
     
    -import org.scalatest.FunSuite
    -
     import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
     
    +import org.apache.spark.{Logging, SparkFunSuite}
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
       SpearmanCorrelation}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
    +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     
       // test input data
       val xData = Array(1.0, 0.0, -2.0)
    @@ -96,11 +95,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
         val X = sc.parallelize(data)
         val defaultMat = Statistics.corr(X)
         val pearsonMat = Statistics.corr(X, "pearson")
    +    // scalastyle:off
         val expected = BDM(
           (1.00000000, 0.05564149, Double.NaN, 0.4004714),
           (0.05564149, 1.00000000, Double.NaN, 0.9135959),
           (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
    -      (0.40047142, 0.91359586, Double.NaN,1.0000000))
    +      (0.40047142, 0.91359586, Double.NaN, 1.0000000))
    +    // scalastyle:on
         assert(matrixApproxEqual(defaultMat.toBreeze, expected))
         assert(matrixApproxEqual(pearsonMat.toBreeze, expected))
       }
    @@ -108,11 +109,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
       test("corr(X) spearman") {
         val X = sc.parallelize(data)
         val spearmanMat = Statistics.corr(X, "spearman")
    +    // scalastyle:off
         val expected = BDM(
           (1.0000000,  0.1054093,  Double.NaN, 0.4000000),
           (0.1054093,  1.0000000,  Double.NaN, 0.9486833),
           (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
           (0.4000000,  0.9486833,  Double.NaN, 1.0000000))
    +    // scalastyle:on
         assert(matrixApproxEqual(spearmanMat.toBreeze, expected))
       }
     
    @@ -143,7 +146,7 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
       def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
         for (i <- 0 until A.rows; j <- 0 until A.cols) {
           if (!approxEqual(A(i, j), B(i, j), threshold)) {
    -        println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
    +        logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
             return false
           }
         }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
    index 15418e6035965..142b90e764a7c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
    @@ -19,16 +19,18 @@ package org.apache.spark.mllib.stat
     
     import java.util.Random
     
    -import org.scalatest.FunSuite
    +import org.apache.commons.math3.distribution.{ExponentialDistribution,
    +  NormalDistribution, UniformRealDistribution}
    +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
     
    -import org.apache.spark.SparkException
    +import org.apache.spark.{SparkException, SparkFunSuite}
     import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.stat.test.ChiSqTest
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext {
    +class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("chi squared pearson goodness of fit") {
     
    @@ -155,4 +157,101 @@ class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext {
           Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
         }
       }
    +
    +  test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") {
    +    // Create theoretical distributions
    +    val stdNormalDist = new NormalDistribution(0, 1)
    +    val expDist = new ExponentialDistribution(0.6)
    +    val unifDist = new UniformRealDistribution()
    +
    +    // set seeds
    +    val seed = 10L
    +    stdNormalDist.reseedRandomGenerator(seed)
    +    expDist.reseedRandomGenerator(seed)
    +    unifDist.reseedRandomGenerator(seed)
    +
    +    // Sample data from the distributions and parallelize it
    +    val n = 100000
    +    val sampledNorm = sc.parallelize(stdNormalDist.sample(n), 10)
    +    val sampledExp = sc.parallelize(expDist.sample(n), 10)
    +    val sampledUnif = sc.parallelize(unifDist.sample(n), 10)
    +
    +    // Use a local Apache Commons Math KS test to verify the calculations
    +    val ksTest = new KolmogorovSmirnovTest()
    +    val pThreshold = 0.05
    +
    +    // Comparing a standard normal sample to a standard normal distribution
    +    val result1 = Statistics.kolmogorovSmirnovTest(sampledNorm, "norm", 0, 1)
    +    val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledNorm.collect())
    +    val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n)
    +    // Verify vs apache math commons ks test
    +    assert(result1.statistic ~== referenceStat1 relTol 1e-4)
    +    assert(result1.pValue ~== referencePVal1 relTol 1e-4)
    +    // Cannot reject null hypothesis
    +    assert(result1.pValue > pThreshold)
    +
    +    // Comparing an exponential sample to a standard normal distribution
    +    val result2 = Statistics.kolmogorovSmirnovTest(sampledExp, "norm", 0, 1)
    +    val referenceStat2 = ksTest.kolmogorovSmirnovStatistic(stdNormalDist, sampledExp.collect())
    +    val referencePVal2 = 1 - ksTest.cdf(referenceStat2, n)
    +    // verify vs apache math commons ks test
    +    assert(result2.statistic ~== referenceStat2 relTol 1e-4)
    +    assert(result2.pValue ~== referencePVal2 relTol 1e-4)
    +    // reject null hypothesis
    +    assert(result2.pValue < pThreshold)
    +
    +    // Testing the use of a user provided CDF function
    +    // Distribution is not serializable, so will have to create in the lambda
    +    val expCDF = (x: Double) => new ExponentialDistribution(0.2).cumulativeProbability(x)
    +
    +    // Comparing an exponential sample with mean X to an exponential distribution with mean Y
    +    // Where X != Y
    +    val result3 = Statistics.kolmogorovSmirnovTest(sampledExp, expCDF)
    +    val referenceStat3 = ksTest.kolmogorovSmirnovStatistic(new ExponentialDistribution(0.2),
    +      sampledExp.collect())
    +    val referencePVal3 = 1 - ksTest.cdf(referenceStat3, sampledNorm.count().toInt)
    +    // verify vs apache math commons ks test
    +    assert(result3.statistic ~== referenceStat3 relTol 1e-4)
    +    assert(result3.pValue ~== referencePVal3 relTol 1e-4)
    +    // reject null hypothesis
    +    assert(result3.pValue < pThreshold)
    +  }
    +
    +  test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") {
    +    /*
    +      Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample
    +      > sessionInfo()
    +      R version 3.2.0 (2015-04-16)
    +      Platform: x86_64-apple-darwin13.4.0 (64-bit)
    +      > set.seed(20)
    +      > v <- rnorm(20)
    +      > v
    +       [1]  1.16268529 -0.58592447  1.78546500 -1.33259371 -0.44656677  0.56960612
    +       [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222
    +      [13] -0.62812676  1.32322085 -1.52135057 -0.43742787  0.97057758  0.02822264
    +      [19] -0.08578219  0.38921440
    +      > ks.test(v, pnorm, alternative = "two.sided")
    +
    +               One-sample Kolmogorov-Smirnov test
    +
    +      data:  v
    +      D = 0.18874, p-value = 0.4223
    +      alternative hypothesis: two-sided
    +    */
    +
    +    val rKSStat = 0.18874
    +    val rKSPVal = 0.4223
    +    val rData = sc.parallelize(
    +      Array(
    +        1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
    +        -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
    +        -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
    +        -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
    +        0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
    +      )
    +    )
    +    val rCompResult = Statistics.kolmogorovSmirnovTest(rData, "norm", 0, 1)
    +    assert(rCompResult.statistic ~== rKSStat relTol 1e-4)
    +    assert(rCompResult.pValue ~== rKSPVal relTol 1e-4)
    +  }
     }
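
The Kolmogorov-Smirnov tests above check Spark's results against Commons Math and against R. As a standalone illustration of what the D statistic measures, here is a hedged local sketch (a hypothetical `KSStatisticSketch` object, not the Spark or Commons Math implementation) that computes the largest gap between a sample's empirical CDF and a theoretical CDF:

```scala
import org.apache.commons.math3.distribution.NormalDistribution

object KSStatisticSketch {
  // One-sample KS statistic: max deviation between the empirical and theoretical CDFs.
  def ksStatistic(sample: Array[Double], cdf: Double => Double): Double = {
    val sorted = sample.sorted
    val n = sorted.length.toDouble
    sorted.zipWithIndex.map { case (x, i) =>
      val theoretical = cdf(x)
      // Compare against the empirical CDF just before (i/n) and just after ((i+1)/n) x.
      math.max(theoretical - i / n, (i + 1) / n - theoretical)
    }.max
  }

  def main(args: Array[String]): Unit = {
    val dist = new NormalDistribution(0, 1)
    dist.reseedRandomGenerator(10L)
    val sample = dist.sample(1000)
    println(s"D = ${ksStatistic(sample, x => dist.cumulativeProbability(x))}")
  }
}
```
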
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
    index 16ecae23dd9d4..5feccdf33681a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
    @@ -17,31 +17,32 @@
     
     package org.apache.spark.mllib.stat
     
    -import org.scalatest.FunSuite
    -
     import org.apache.commons.math3.distribution.NormalDistribution
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
    -class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
    +class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
       test("kernel density single sample") {
         val rdd = sc.parallelize(Array(5.0))
         val evaluationPoints = Array(5.0, 6.0)
    -    val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
    +    val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
         val normal = new NormalDistribution(5.0, 3.0)
         val acceptableErr = 1e-6
    -    assert(densities(0) - normal.density(5.0) < acceptableErr)
    -    assert(densities(0) - normal.density(6.0) < acceptableErr)
    +    assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr)
    +    assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr)
       }
     
       test("kernel density multiple samples") {
         val rdd = sc.parallelize(Array(5.0, 10.0))
         val evaluationPoints = Array(5.0, 6.0)
    -    val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
    +    val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
         val normal1 = new NormalDistribution(5.0, 3.0)
         val normal2 = new NormalDistribution(10.0, 3.0)
         val acceptableErr = 1e-6
    -    assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr)
    -    assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr)
    +    assert(math.abs(
    +      densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr)
    +    assert(math.abs(
    +      densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr)
       }
     }
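
The updated suite above exercises the builder-style `new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(points)` API and compares the results to averaged normal densities. As a rough illustration of that underlying math, a minimal local sketch (a hypothetical `KernelDensitySketch`, not the Spark class) that averages Gaussian kernels centred at each sample point:

```scala
object KernelDensitySketch {
  // Gaussian KDE: the density at x is the mean of normal PDFs centred at each sample.
  def estimate(sample: Array[Double], bandwidth: Double, points: Array[Double]): Array[Double] = {
    val norm = 1.0 / (bandwidth * math.sqrt(2.0 * math.Pi))
    points.map { x =>
      sample.map { s =>
        val z = (x - s) / bandwidth
        norm * math.exp(-0.5 * z * z)
      }.sum / sample.length
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the "multiple samples" case above: samples 5.0 and 10.0, bandwidth 3.0.
    val densities = estimate(Array(5.0, 10.0), 3.0, Array(5.0, 6.0))
    println(densities.mkString(", "))
  }
}
```
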
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
    index 23b0eec865de6..07efde4f5e6dc 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
    @@ -17,12 +17,11 @@
     
     package org.apache.spark.mllib.stat
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class MultivariateOnlineSummarizerSuite extends FunSuite {
    +class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
     
       test("basic error handing") {
         val summarizer = new MultivariateOnlineSummarizer
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
    index fac2498e4dcb3..aa60deb665aeb 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
    @@ -17,49 +17,48 @@
     
     package org.apache.spark.mllib.stat.distribution
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
     
    -class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
    +class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("univariate") {
         val x1 = Vectors.dense(0.0)
         val x2 = Vectors.dense(1.5)
    -                     
    +
         val mu = Vectors.dense(0.0)
         val sigma1 = Matrices.dense(1, 1, Array(1.0))
         val dist1 = new MultivariateGaussian(mu, sigma1)
         assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
         assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
    -    
    +
         val sigma2 = Matrices.dense(1, 1, Array(4.0))
         val dist2 = new MultivariateGaussian(mu, sigma2)
         assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
         assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
       }
    -  
    +
       test("multivariate") {
         val x1 = Vectors.dense(0.0, 0.0)
         val x2 = Vectors.dense(1.0, 1.0)
    -    
    +
         val mu = Vectors.dense(0.0, 0.0)
         val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
         val dist1 = new MultivariateGaussian(mu, sigma1)
         assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
         assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
    -    
    +
         val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
         val dist2 = new MultivariateGaussian(mu, sigma2)
         assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
         assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
       }
    -  
    +
       test("multivariate degenerate") {
         val x1 = Vectors.dense(0.0, 0.0)
         val x2 = Vectors.dense(1.0, 1.0)
    -    
    +
         val mu = Vectors.dense(0.0, 0.0)
         val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
         val dist = new MultivariateGaussian(mu, sigma)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
    index ce983eb27fa35..356d957f15909 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
    @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree
     import scala.collection.JavaConverters._
     import scala.collection.mutable
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.configuration.Algo._
    @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.util.Utils
     
     
    -class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
    +class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       /////////////////////////////////////////////////////////////////////////////
       // Tests examining individual elements of training
    @@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
     
    -object DecisionTreeSuite extends FunSuite {
    +object DecisionTreeSuite extends SparkFunSuite {
     
       def validateClassifier(
           model: DecisionTreeModel,
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
    index 55b0bac7d49fe..2521b3342181a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
    @@ -17,8 +17,7 @@
     
     package org.apache.spark.mllib.tree
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.{Logging, SparkFunSuite}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.configuration.Algo._
     import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
    @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils
     /**
      * Test suite for [[GradientBoostedTrees]].
      */
    -class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
    +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
     
       test("Regression with continuous features: SquaredError") {
         GradientBoostedTreesSuite.testCombinations.foreach {
    @@ -51,7 +50,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
               EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
             } catch {
               case e: java.lang.AssertionError =>
    -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
    +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                   s" subsamplingRate=$subsamplingRate")
                 throw e
             }
    @@ -81,7 +80,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
               EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae")
             } catch {
               case e: java.lang.AssertionError =>
    -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
    +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                   s" subsamplingRate=$subsamplingRate")
                 throw e
             }
    @@ -112,7 +111,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
               EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9)
             } catch {
               case e: java.lang.AssertionError =>
    -            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
    +            logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
                   s" subsamplingRate=$subsamplingRate")
                 throw e
             }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
    index 92b498580af03..49aff21fe7914 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.tree
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
     /**
      * Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
      */
    -class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
    +class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext {
       test("Gini impurity does not support negative labels") {
         val gini = new GiniAggregator(2)
         intercept[IllegalArgumentException] {
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
    index ee3bc98486862..e6df5d974bf36 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
    @@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree
     
     import scala.collection.mutable
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.tree.configuration.Algo._
    @@ -35,7 +34,7 @@ import org.apache.spark.util.Utils
     /**
      * Test suite for [[RandomForest]].
      */
    -class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
    +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
         val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
         val rdd = sc.parallelize(arr)
    @@ -196,7 +195,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
           numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
         val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
           featureSubsetStrategy = "sqrt", seed = 12345)
    -    EnsembleTestHelper.validateClassifier(model, arr, 1.0)
       }
     
       test("subsampling rate in RandomForest"){
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
    index b184e936672ca..9d756da410325 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
    @@ -17,15 +17,14 @@
     
     package org.apache.spark.mllib.tree.impl
     
    -import org.scalatest.FunSuite
    -
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.tree.EnsembleTestHelper
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     
     /**
      * Test suite for [[BaggedPoint]].
      */
    -class BaggedPointSuite extends FunSuite with MLlibTestSparkContext  {
    +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext  {
     
       test("BaggedPoint RDD: without subsampling") {
         val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
    index 668fc1d43c5d6..70219e9ad9d3e 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
    @@ -21,19 +21,19 @@ import java.io.File
     
     import scala.io.Source
     
    -import org.scalatest.FunSuite
    -
     import breeze.linalg.{squaredDistance => breezeSquaredDistance}
     import com.google.common.base.Charsets
     import com.google.common.io.Files
     
    +import org.apache.spark.SparkException
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.MLUtils._
     import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
    -class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
    +class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("epsilon computation") {
         assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
    @@ -63,7 +63,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
           val fastSquaredDist3 =
             fastSquaredDistance(v2, norm2, v3, norm3, precision)
           assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m")
    -      if (m > 10) { 
    +      if (m > 10) {
             val v4 = Vectors.sparse(n, indices.slice(0, m - 10),
               indices.map(i => a(i) + 0.5).slice(0, m - 10))
             val norm4 = Vectors.norm(v4, 2.0)
    @@ -109,6 +109,40 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
         Utils.deleteRecursively(tempDir)
       }
     
    +  test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
    +    val lines =
    +      """
    +        |0
    +        |0 0:4.0 4:5.0 6:6.0
    +      """.stripMargin
    +    val tempDir = Utils.createTempDir()
    +    val file = new File(tempDir.getPath, "part-00000")
    +    Files.write(lines, file, Charsets.US_ASCII)
    +    val path = tempDir.toURI.toString
    +
    +    intercept[SparkException] {
    +      loadLibSVMFile(sc, path).collect()
    +    }
    +    Utils.deleteRecursively(tempDir)
    +  }
    +
    +  test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
    +    val lines =
    +      """
    +        |0
    +        |0 3:4.0 2:5.0 6:6.0
    +      """.stripMargin
    +    val tempDir = Utils.createTempDir()
    +    val file = new File(tempDir.getPath, "part-00000")
    +    Files.write(lines, file, Charsets.US_ASCII)
    +    val path = tempDir.toURI.toString
    +
    +    intercept[SparkException] {
    +      loadLibSVMFile(sc, path).collect()
    +    }
    +    Utils.deleteRecursively(tempDir)
    +  }
    +
       test("saveAsLibSVMFile") {
         val examples = sc.parallelize(Seq(
           LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
    @@ -168,7 +202,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
                 "Each training+validation set combined should contain all of the data.")
             }
             // K fold cross validation should only have each element in the validation set exactly once
    -        assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted ===
    +        assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted ===
               data.collect().sorted)
           }
         }
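
The two tests added above pin down loadLibSVMFile's input validation: LibSVM indices must be one-based and strictly ascending, and violations surface as a SparkException once the RDD is materialized. For contrast, here is a minimal sketch of loading a well-formed file; the object name, labels, values, and app name are illustrative only and not taken from the patch.

```
import java.io.File

import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils

object LibSVMExample {
  def main(args: Array[String]): Unit = {
    // LibSVM records are "<label> <index>:<value> ...", with one-based, ascending indices.
    val lines =
      """
        |1.0 1:1.5 3:2.0
        |0.0 2:3.0 4:4.5
      """.stripMargin
    val tempDir = Files.createTempDir()
    Files.write(lines, new File(tempDir, "part-00000"), Charsets.US_ASCII)

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("libsvm-example"))
    // Indices are converted to zero-based internally; zero-based or unsorted input
    // triggers the failures exercised by the new tests.
    val points = MLUtils.loadLibSVMFile(sc, tempDir.toURI.toString)
    points.collect().foreach(println)
    sc.stop()
  }
}
```
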
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
    index b658889476d37..5d1796ef65722 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
    @@ -17,13 +17,14 @@
     
     package org.apache.spark.mllib.util
     
    -import org.scalatest.Suite
    -import org.scalatest.BeforeAndAfterAll
    +import org.scalatest.{BeforeAndAfterAll, Suite}
     
     import org.apache.spark.{SparkConf, SparkContext}
    +import org.apache.spark.sql.SQLContext
     
     trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
       @transient var sc: SparkContext = _
    +  @transient var sqlContext: SQLContext = _
     
       override def beforeAll() {
         super.beforeAll()
    @@ -31,12 +32,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
           .setMaster("local[2]")
           .setAppName("MLlibUnitTest")
         sc = new SparkContext(conf)
    +    sqlContext = new SQLContext(sc)
       }
     
       override def afterAll() {
    +    sqlContext = null
         if (sc != null) {
           sc.stop()
         }
    +    sc = null
         super.afterAll()
       }
     }
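
Since MLlibTestSparkContext now also stands up a SQLContext (and nulls out both handles in afterAll), DataFrame-based ML suites can rely on the shared instance instead of constructing their own. A minimal sketch of a suite using it, assuming it lives inside Spark's own test tree (SparkFunSuite is private[spark]); the suite name and data are made up for illustration.

```
package org.apache.spark.mllib.util

import org.apache.spark.SparkFunSuite

class ExampleDataFrameSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("shared sqlContext is available to DataFrame-based tests") {
    // sqlContext is created in beforeAll() alongside sc and cleared in afterAll().
    val df = sqlContext.createDataFrame(Seq((0.0, "a"), (1.0, "b"))).toDF("label", "text")
    assert(df.count() === 2L)
  }
}
```
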
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
    index f68fb95eac4e4..16d7c3ab39b03 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
    @@ -17,11 +17,9 @@
     
     package org.apache.spark.mllib.util
     
    -import org.scalatest.FunSuite
    +import org.apache.spark.{SparkException, SparkFunSuite}
     
    -import org.apache.spark.SparkException
    -
    -class NumericParserSuite extends FunSuite {
    +class NumericParserSuite extends SparkFunSuite {
     
       test("parser") {
         val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)"
    @@ -35,8 +33,15 @@ class NumericParserSuite extends FunSuite {
         malformatted.foreach { s =>
           intercept[SparkException] {
             NumericParser.parse(s)
    -        println(s"Didn't detect malformatted string $s.")
    +        throw new RuntimeException(s"Didn't detect malformatted string $s.")
           }
         }
       }
    +
    +  test("parser with whitespaces") {
    +    val s = "(0.0, [1.0, 2.0])"
    +    val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]]
    +    assert(parsed(0).asInstanceOf[Double] === 0.0)
    +    assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0))
    +  }
     }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
    index 59e6c778806f4..8f475f30249d6 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
    @@ -17,12 +17,12 @@
     
     package org.apache.spark.mllib.util
     
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.linalg.Vectors
    -import org.scalatest.FunSuite
     import org.apache.spark.mllib.util.TestingUtils._
     import org.scalatest.exceptions.TestFailedException
     
    -class TestingUtilsSuite extends FunSuite {
    +class TestingUtilsSuite extends SparkFunSuite {
     
       test("Comparing doubles using relative error.") {
     
    diff --git a/network/common/pom.xml b/network/common/pom.xml
    index 0c3147761cfc5..7dc3068ab8cb7 100644
    --- a/network/common/pom.xml
    +++ b/network/common/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.spark
         spark-parent_2.10
    -    1.4.0-SNAPSHOT
    +    1.5.0-SNAPSHOT
         ../../pom.xml
       
     
    @@ -77,7 +77,7 @@
         
         
           org.mockito
    -      mockito-all
    +      mockito-core
           test
         
         
    diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
    index 6b514aaa1290d..7d27439cfde7a 100644
    --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
    +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
    @@ -39,6 +39,12 @@
     public class JavaUtils {
       private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class);
     
    +  /**
    +   * Define a default value for driver memory here since this value is referenced across the code
    +   * base and nearly all files already use Utils.scala
    +   */
    +  public static final long DEFAULT_DRIVER_MEM_MB = 1024;
    +
       /** Closes the given object, ignoring IOExceptions. */
       public static void closeQuietly(Closeable closeable) {
         try {
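
DEFAULT_DRIVER_MEM_MB gives the many call sites that need a driver-memory fallback a single shared constant. A hedged, spark-shell-style sketch of the kind of lookup the comment describes; spark.driver.memory is a real configuration key, but this particular snippet is illustrative rather than code from the patch.

```
import org.apache.spark.SparkConf
import org.apache.spark.network.util.JavaUtils

val conf = new SparkConf()
// Fall back to the shared default (1024 MB) when spark.driver.memory is not set.
val driverMemoryMb: Long =
  conf.getSizeAsMb("spark.driver.memory", JavaUtils.DEFAULT_DRIVER_MEM_MB + "m")
```
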
    diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
    index 7dc7c65825e34..532463e96fbb7 100644
    --- a/network/shuffle/pom.xml
    +++ b/network/shuffle/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.spark
         spark-parent_2.10
    -    1.4.0-SNAPSHOT
    +    1.5.0-SNAPSHOT
         ../../pom.xml
       
     
    @@ -79,7 +79,7 @@
         
         
           org.mockito
    -      mockito-all
    +      mockito-core
           test
         
       
    diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
    index dd08e24cade23..022ed88a16480 100644
    --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
    +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
    @@ -108,7 +108,8 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) {
     
         if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) {
           return getHashBasedShuffleBlockData(executor, blockId);
    -    } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) {
    +    } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)
    +      || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) {
           return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
         } else {
           throw new UnsupportedOperationException(
    diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
    index 60485bace643c..ce954b8a289e4 100644
    --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
    +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
    @@ -24,6 +24,9 @@
     
     import org.apache.spark.network.protocol.Encoders;
     
    +// Needed by ScalaDoc. See SPARK-7726
    +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
    +
     /** Request to read a set of blocks. Returns {@link StreamHandle}. */
     public class OpenBlocks extends BlockTransferMessage {
       public final String appId;
    diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
    index 38acae3b31d64..cca8b17c4f129 100644
    --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
    +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
    @@ -22,6 +22,9 @@
     
     import org.apache.spark.network.protocol.Encoders;
     
    +// Needed by ScalaDoc. See SPARK-7726
    +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
    +
     /**
      * Initial registration message between an executor and its local shuffle server.
 * Returns nothing (empty byte array).
    diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
    index 9a9220211a50c..1915295aa6cc2 100644
    --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
    +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
    @@ -20,6 +20,9 @@
     import com.google.common.base.Objects;
     import io.netty.buffer.ByteBuf;
     
    +// Needed by ScalaDoc. See SPARK-7726
    +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
    +
     /**
      * Identifier for a fixed number of chunks to read from a stream created by an "open blocks"
      * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}.
    diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
    index 2ff9aaa650f92..3caed59d508fd 100644
    --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
    +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
    @@ -24,6 +24,9 @@
     
     import org.apache.spark.network.protocol.Encoders;
     
    +// Needed by ScalaDoc. See SPARK-7726
    +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
    +
     
     /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
     public class UploadBlock extends BlockTransferMessage {
    diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
    index 1e2e9c80af6cc..a99f7c4392d3d 100644
    --- a/network/yarn/pom.xml
    +++ b/network/yarn/pom.xml
    @@ -22,7 +22,7 @@
       
         org.apache.spark
         spark-parent_2.10
    -    1.4.0-SNAPSHOT
    +    1.5.0-SNAPSHOT
         ../../pom.xml
       
     
    diff --git a/pom.xml b/pom.xml
    index cf9279ea5a2a6..370c95dd03632 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -26,7 +26,7 @@
       
       org.apache.spark
       spark-parent_2.10
    -  1.4.0-SNAPSHOT
    +  1.5.0-SNAPSHOT
       pom
       Spark Project Parent POM
       http://spark.apache.org/
    @@ -102,40 +102,43 @@
         external/twitter
         external/flume
         external/flume-sink
    +    external/flume-assembly
         external/mqtt
         external/zeromq
         examples
         repl
         launcher
    +    external/kafka
    +    external/kafka-assembly
       
     
       
         UTF-8
         UTF-8
    -    org.spark-project.akka
    -    2.3.4-spark
    -    1.6
    +    com.typesafe.akka
    +    2.3.11
    +    1.7
         spark
    -    2.0.1
         0.21.1
         shaded-protobuf
         1.7.10
         1.2.17
         2.2.0
    -    2.4.1
    +    2.5.0
         ${hadoop.version}
    -    0.98.7-hadoop1
    +    0.98.7-hadoop2
         hbase
    -    1.4.0
    +    1.6.0
         3.4.5
    +    2.4.0
         org.spark-project.hive
         
         0.13.1a
         
         0.13.1
         10.10.1.1
    -    1.6.0rc3
    -    1.2.3
    +    1.7.0
    +    1.2.4
         8.1.14.v20131031
         3.0.0.v201112011016
         0.5.0
    @@ -143,10 +146,10 @@
         2.0.8
         3.1.0
         1.7.7
    -    
    +    hadoop2
         0.7.1
    -    1.8.3
    -    1.1.0
    +    1.9.16
    +    1.2.1
         4.3.2
         3.4.1
         ${project.build.directory}/spark-test-classpath.txt
    @@ -154,11 +157,12 @@
         2.10
         ${scala.version}
         org.scala-lang
    -    3.6.3
    -    1.8.8
    +    1.9.13
         2.4.4
         1.1.1.7
         1.1.2
    +    
    +    false
     
         ${java.home}
     
    @@ -175,9 +179,10 @@
         compile
         compile
         compile
    +    test
     
         
         ${session.executionRootDirectory}
    @@ -247,7 +252,7 @@
         
           mapr-repo
           MapR Repository
    -      http://repository.mapr.com/maven
    +      http://repository.mapr.com/maven/
           
             true
           
    @@ -266,6 +271,30 @@
             false
           
         
    +    
    +    
    +      twttr-repo
    +      Twttr Repository
    +      http://maven.twttr.com
    +      
    +        true
    +      
    +      
    +        false
    +      
    +    
    +    
    +    
    +      spark-1.4-staging
    +      Spark 1.4 RC4 Staging Repository
    +      https://repository.apache.org/content/repositories/orgapachespark-1112
    +      
    +        true
    +      
    +      
    +        false
    +      
    +    
       
       
         
    @@ -312,11 +341,6 @@
       
       
         
    -      
    -        ${jline.groupid}
    -        jline
    -        ${jline.version}
    -      
           
             com.twitter
             chill_${scala.binary.version}
    @@ -492,7 +516,7 @@
           
             net.jpountz.lz4
             lz4
    -        1.2.0
    +        1.3.0
           
           
             com.clearspring.analytics
    @@ -573,7 +597,7 @@
           
             io.netty
             netty-all
    -        4.0.23.Final
    +        4.0.28.Final
           
           
             org.apache.derby
    @@ -668,8 +692,8 @@
           
           
             org.mockito
    -        mockito-all
    -        1.9.0
    +        mockito-core
    +        1.9.5
             test
           
           
    @@ -684,6 +708,18 @@
             4.10
             test
           
    +      
    +        org.hamcrest
    +        hamcrest-core
    +        1.3
    +        test
    +      
    +      
    +        org.hamcrest
    +        hamcrest-library
    +        1.3
    +        test
    +      
           
             com.novocode
             junit-interface
    @@ -693,7 +729,7 @@
           
             org.apache.curator
             curator-recipes
    -        2.4.0
    +        ${curator.version}
             ${hadoop.deps.scope}
             
               
    @@ -702,6 +738,16 @@
               
             
           
    +      
    +        org.apache.curator
    +        curator-client
    +        ${curator.version}
    +      
    +      
    +        org.apache.curator
    +        curator-framework
    +        ${curator.version}
    +      
           
             org.apache.hadoop
             hadoop-client
    @@ -712,6 +758,10 @@
                 asm
                 asm
               
    +          
    +            org.codehaus.jackson
    +            jackson-mapper-asl
    +          
               
                 org.ow2.asm
                 asm
    @@ -724,6 +774,10 @@
                 commons-logging
                 commons-logging
               
    +          
    +            org.mockito
    +            mockito-all
    +          
               
                 org.mortbay.jetty
                 servlet-api-2.5
    @@ -1044,17 +1098,23 @@
             
           
           
    -        com.twitter
    +        org.apache.parquet
             parquet-column
             ${parquet.version}
             ${parquet.deps.scope}
           
           
    -        com.twitter
    +        org.apache.parquet
             parquet-hadoop
             ${parquet.version}
             ${parquet.deps.scope}
           
    +      
    +        org.apache.parquet
    +        parquet-avro
    +        ${parquet.version}
    +        ${parquet.test.deps.scope}
    +      
           
             org.apache.flume
             flume-ng-core
    @@ -1065,6 +1125,10 @@
                 io.netty
                 netty
               
    +          
    +            org.apache.flume
    +            flume-ng-auth
    +          
               
                 org.apache.thrift
                 libthrift
    @@ -1180,15 +1244,6 @@
                   -target
                   ${java.version}
                 
    -            
    -            
    -              
    -                org.scalamacros
    -                paradise_${scala.version}
    -                ${scala.macros.version}
    -              
    -            
               
             
             
    @@ -1217,7 +1272,7 @@
                   **/*Suite.java
                 
                 ${project.build.directory}/surefire-reports
    -            -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
    +            -Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
                 
                   
    +          ${create.dependency.reduced.pom}
               
                 
                   
    @@ -1517,6 +1579,26 @@
               
             
           
    +
    +      
    +        org.apache.maven.plugins
    +        maven-antrun-plugin
    +        
    +          
    +            create-tmp-dir
    +            generate-test-resources
    +            
    +              run
    +            
    +            
    +              
    +                
    +              
    +            
    +          
    +        
    +      
    +
           
           
             org.apache.maven.plugins
    @@ -1632,26 +1714,28 @@
         -->
     
         
    -      hadoop-2.2
    +      hadoop-1
           
    -        2.2.0
    -        2.5.0
    -        0.98.7-hadoop2
    -        hadoop2
    -        1.9.13
    +        1.2.1
    +        2.4.1
    +        0.98.7-hadoop1
    +        hadoop1
    +        1.8.8
    +        org.spark-project.akka
    +        2.3.4-spark
           
         
     
    +    
    +      hadoop-2.2
    +    
    +    
    +
         
           hadoop-2.3
           
             2.3.0
    -        2.5.0
             0.9.3
    -        0.98.7-hadoop2
    -        3.1.1
    -        hadoop2
    -        1.9.13
           
         
     
    @@ -1659,12 +1743,17 @@
           hadoop-2.4
           
             2.4.0
    -        2.5.0
             0.9.3
    -        0.98.7-hadoop2
    -        3.1.1
    -        hadoop2
    -        1.9.13
    +      
    +    
    +
    +    
    +      hadoop-2.6
    +      
    +        2.6.0
    +        0.9.3
    +        3.4.6
    +        2.6.0
           
         
     
    @@ -1698,7 +1787,7 @@
             
               org.apache.curator
               curator-recipes
    -          2.4.0
    +          ${curator.version}
               
                 
                   org.apache.zookeeper
    @@ -1720,22 +1809,6 @@
             sql/hive-thriftserver
           
         
    -    
    -      hive-0.12.0
    -      
    -        0.12.0-protobuf-2.5
    -        0.12.0
    -        10.4.2.0
    -      
    -    
    -    
    -      hive-0.13.1
    -      
    -        0.13.1a
    -        0.13.1
    -        10.10.1.1
    -      
    -    
     
         
           scala-2.10
    @@ -1748,10 +1821,15 @@
             ${scala.version}
             org.scala-lang
           
    -      
    -        external/kafka
    -        external/kafka-assembly
    -      
    +      
    +        
    +          
    +            ${jline.groupid}
    +            jline
    +            ${jline.version}
    +          
    +        
    +      
         
     
         
    @@ -1770,10 +1848,28 @@
             scala-2.11
           
           
    -        2.11.6
    +        2.11.7
             2.11
    -        2.12.1
    -        jline
    +      
    +    
    +
    +    
    +      
    +      release
    +      
    +        
    +        true
           
         
     
    diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
    index dde92949fa175..f16bf989f200b 100644
    --- a/project/MimaBuild.scala
    +++ b/project/MimaBuild.scala
    @@ -91,7 +91,7 @@ object MimaBuild {
     
       def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
         val organization = "org.apache.spark"
    -    val previousSparkVersion = "1.3.0"
    +    val previousSparkVersion = "1.4.0"
         val fullId = "spark-" + projectRef.project + "_2.10"
         mimaDefaultSettings ++
         Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
    diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
    index a47e29e2ef365..4e4e810ec36e3 100644
    --- a/project/MimaExcludes.scala
    +++ b/project/MimaExcludes.scala
    @@ -34,10 +34,75 @@ import com.typesafe.tools.mima.core.ProblemFilters._
     object MimaExcludes {
         def excludes(version: String) =
           version match {
    +        case v if v.startsWith("1.5") =>
    +          Seq(
    +            MimaBuild.excludeSparkPackage("deploy"),
    +            // These are needed if checking against the sbt build, since they are part of
    +            // the maven-generated artifacts in 1.3.
    +            excludePackage("org.spark-project.jetty"),
    +            MimaBuild.excludeSparkPackage("unused"),
    +            // JavaRDDLike is not meant to be extended by user programs
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.java.JavaRDDLike.partitioner"),
    +            // Modification of private static method
    +            ProblemFilters.exclude[IncompatibleMethTypeProblem](
    +              "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"),
    +            // Mima false positive (was a private[spark] class)
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.util.collection.PairIterator"),
    +            // Removing a testing method from a private class
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
+            // While private, MiMa is still not happy about the changes,
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.ml.classification.LogisticCostFun.this"),
    +            // SQL execution is considered private.
    +            excludePackage("org.apache.spark.sql.execution"),
    +            // Parquet support is considered private.
    +            excludePackage("org.apache.spark.sql.parquet"),
    +            // local function inside a method
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
    +          ) ++ Seq(
    +            // SPARK-8479 Add numNonzeros and numActives to Matrix.
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.mllib.linalg.Matrix.numNonzeros"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.mllib.linalg.Matrix.numActives")
    +          ) ++ Seq(
    +            // SPARK-8914 Remove RDDApi
    +            ProblemFilters.exclude[MissingClassProblem](
    +            "org.apache.spark.sql.RDDApi")
    +          ) ++ Seq(
    +            // SPARK-8701 Add input metadata in the batch page.
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.streaming.scheduler.InputInfo$"),
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.streaming.scheduler.InputInfo")
    +          ) ++ Seq(
    +            // SPARK-6797 Support YARN modes for SparkR
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.PairwiseRRDD.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.RRDD.createRWorker"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.RRDD.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.StringRRDD.this"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.api.r.BaseRRDD.this")
    +          )
    +
             case v if v.startsWith("1.4") =>
               Seq(
                 MimaBuild.excludeSparkPackage("deploy"),
                 MimaBuild.excludeSparkPackage("ml"),
+            // SPARK-7910 Adding a method to get the partitioner to JavaRDD,
    +            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"),
                 // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
                 ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
                 // These are needed if checking against the sbt build, since they are part of
    @@ -87,7 +152,14 @@ object MimaExcludes {
                 ProblemFilters.exclude[MissingMethodProblem](
                   "org.apache.spark.mllib.linalg.Vector.toSparse"),
                 ProblemFilters.exclude[MissingMethodProblem](
    -              "org.apache.spark.mllib.linalg.Vector.numActives")
    +              "org.apache.spark.mllib.linalg.Vector.numActives"),
    +            // SPARK-7681 add SparseVector support for gemv
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.mllib.linalg.Matrix.multiply"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
    +            ProblemFilters.exclude[MissingMethodProblem](
    +              "org.apache.spark.mllib.linalg.SparseMatrix.multiply")
               ) ++ Seq(
                 // Execution should never be included as its always internal.
                 MimaBuild.excludeSparkPackage("sql.execution"),
    @@ -111,17 +183,43 @@ object MimaExcludes {
                   "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"),
                 ProblemFilters.exclude[MissingClassProblem](
                   "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"),
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.sql.parquet.ParquetRelation2"),
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.sql.parquet.ParquetRelation2$"),
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"),
                 // These test support classes were moved out of src/main and into src/test:
                 ProblemFilters.exclude[MissingClassProblem](
                   "org.apache.spark.sql.parquet.ParquetTestData"),
                 ProblemFilters.exclude[MissingClassProblem](
                   "org.apache.spark.sql.parquet.ParquetTestData$"),
                 ProblemFilters.exclude[MissingClassProblem](
    -              "org.apache.spark.sql.parquet.TestGroupWriteSupport")
    +              "org.apache.spark.sql.parquet.TestGroupWriteSupport"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"),
    +            // TODO: Remove the following rule once ParquetTest has been moved to src/test.
    +            ProblemFilters.exclude[MissingClassProblem](
    +              "org.apache.spark.sql.parquet.ParquetTest")
               ) ++ Seq(
                 // SPARK-7530 Added StreamingContext.getState()
                 ProblemFilters.exclude[MissingMethodProblem](
                   "org.apache.spark.streaming.StreamingContext.state_=")
    +          ) ++ Seq(
    +            // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
    +            // unnecessary type bounds in order to fix some compiler warnings that occurred when
    +            // implementing this interface in Java. Note that ShuffleWriter is private[spark].
    +            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
    +              "org.apache.spark.shuffle.ShuffleWriter")
    +          ) ++ Seq(
    +            // SPARK-6888 make jdbc driver handling user definable
    +            // This patch renames some classes to API friendly names.
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"),
    +            ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks")
               )
     
             case v if v.startsWith("1.3") =>
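
The inline comments above record why each exclusion is safe (private classes, internal packages, MiMa false positives). When MiMa flags an intentional binary change, the remedy is to append another filter to the version block it belongs to; a hedged sketch of the shape such an entry takes, using a hypothetical class and method name rather than anything from this patch:

```
import com.typesafe.tools.mima.core.MissingMethodProblem
import com.typesafe.tools.mima.core.ProblemFilters

// Hypothetical entry for the "1.5" block: silences a report about a method
// removed from a private[spark] class that was never public API.
ProblemFilters.exclude[MissingMethodProblem](
  "org.apache.spark.example.SomeInternalClass.someRemovedMethod")
```
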
    diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
    index 1b87e4e98bd83..4291b0be2a616 100644
    --- a/project/SparkBuild.scala
    +++ b/project/SparkBuild.scala
    @@ -23,11 +23,12 @@ import scala.collection.JavaConversions._
     import sbt._
     import sbt.Classpaths.publishTask
     import sbt.Keys._
    -import sbtunidoc.Plugin.genjavadocSettings
     import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
     import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys}
     import net.virtualvoid.sbt.graph.Plugin.graphSettings
     
    +import spray.revolver.RevolverPlugin._
    +
     object BuildCommons {
     
       private val buildLocation = file(".").getAbsoluteFile.getParentFile
    @@ -44,14 +45,16 @@ object BuildCommons {
         sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
         "kinesis-asl").map(ProjectRef(buildLocation, _))
     
    -  val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) =
    -    Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly")
    +  val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) =
    +    Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly")
           .map(ProjectRef(buildLocation, _))
     
       val tools = ProjectRef(buildLocation, "tools")
       // Root project.
       val spark = ProjectRef(buildLocation, "spark")
       val sparkHome = buildLocation
    +
    +  val testTempDir = s"$sparkHome/target/tmp"
     }
     
     object SparkBuild extends PomBuild {
    @@ -66,6 +69,7 @@ object SparkBuild extends PomBuild {
         import scala.collection.mutable
         var isAlphaYarn = false
         var profiles: mutable.Seq[String] = mutable.Seq("sbt")
    +    // scalastyle:off println
         if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) {
           println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.")
           profiles ++= Seq("spark-ganglia-lgpl")
    @@ -85,6 +89,7 @@ object SparkBuild extends PomBuild {
           println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.")
           profiles ++= Seq("yarn")
         }
    +    // scalastyle:on println
         profiles
       }
     
    @@ -93,8 +98,10 @@ object SparkBuild extends PomBuild {
         case None => backwardCompatibility
         case Some(v) =>
           if (backwardCompatibility.nonEmpty)
    +        // scalastyle:off println
             println("Note: We ignore environment variables, when use of profile is detected in " +
               "conjunction with environment variable.")
    +        // scalastyle:on println
           v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
         }
     
    @@ -118,7 +125,12 @@ object SparkBuild extends PomBuild {
       lazy val MavenCompile = config("m2r") extend(Compile)
       lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
     
    -  lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq (
    +  lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq(
    +    libraryDependencies += compilerPlugin(
    +      "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full),
    +    scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java")))
    +
    +  lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq (
         javaHome := sys.env.get("JAVA_HOME")
           .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() })
           .map(file),
    @@ -126,7 +138,7 @@ object SparkBuild extends PomBuild {
         retrieveManaged := true,
         retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
         publishMavenStyle := true,
    -    unidocGenjavadocVersion := "0.8",
    +    unidocGenjavadocVersion := "0.9-spark0",
     
         resolvers += Resolver.mavenLocal,
         otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))),
    @@ -140,7 +152,9 @@ object SparkBuild extends PomBuild {
         javacOptions in (Compile, doc) ++= {
           val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3)
           if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty
    -    }
    +    },
    +
    +    javacOptions in Compile ++= Seq("-encoding", "UTF-8")
       )
     
       def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
    @@ -151,14 +165,13 @@ object SparkBuild extends PomBuild {
       // Note ordering of these settings matter.
       /* Enable shared settings on all projects */
       (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools))
    -    .foreach(enable(sharedSettings ++ ExludedDependencies.settings))
    +    .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings))
     
       /* Enable tests settings for all projects except examples, assembly and tools */
       (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
     
    -  // TODO: remove launcher from this list after 1.4.0
       allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl,
    -    networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach {
    +    networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach {
           x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
         }
     
    @@ -174,9 +187,6 @@ object SparkBuild extends PomBuild {
       /* Enable unidoc only for the root spark project */
       enable(Unidoc.settings)(spark)
     
    -  /* Catalyst macro settings */
    -  enable(Catalyst.settings)(catalyst)
    -
       /* Spark SQL Core console settings */
       enable(SQL.settings)(sql)
     
    @@ -200,7 +210,7 @@ object SparkBuild extends PomBuild {
         fork := true,
         outputStrategy in run := Some (StdoutOutput),
     
    -    javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"),
    +    javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m"),
     
         sparkShell := {
           (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value
    @@ -240,7 +250,7 @@ object Flume {
       This excludes library dependencies in sbt, which are specified in maven but are
       not needed by sbt build.
       */
    -object ExludedDependencies {
    +object ExcludedDependencies {
       lazy val settings = Seq(
         libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }
       )
    @@ -271,14 +281,6 @@ object OldDeps {
       )
     }
     
    -object Catalyst {
    -  lazy val settings = Seq(
    -    addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full),
    -    // Quasiquotes break compiling scala doc...
    -    // TODO: Investigate fixing this.
    -    sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen")))
    -}
    -
     object SQL {
       lazy val settings = Seq(
         initialCommands in console :=
    @@ -301,7 +303,7 @@ object SQL {
     object Hive {
     
       lazy val settings = Seq(
    -    javaOptions += "-XX:MaxPermSize=1g",
    +    javaOptions += "-XX:MaxPermSize=256m",
         // Specially disable assertions since some Hive tests fail them
         javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
         // Multiple queries rely on the TestHive singleton. See comments there for more details.
    @@ -324,6 +326,7 @@ object Hive {
             |import org.apache.spark.sql.functions._
             |import org.apache.spark.sql.hive._
             |import org.apache.spark.sql.hive.test.TestHive._
    +        |import org.apache.spark.sql.hive.test.TestHive.implicits._
             |import org.apache.spark.sql.types._""".stripMargin,
         cleanupCommands in console := "sparkContext.stop()",
         // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce
    @@ -348,7 +351,7 @@ object Assembly {
             .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
         },
         jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
    -      if (mName.contains("streaming-kafka-assembly")) {
    +      if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) {
             // This must match the same name used in maven (see external/kafka-assembly/pom.xml)
             s"${mName}-${v}.jar"
           } else {
    @@ -502,6 +505,7 @@ object TestSettings {
           "SPARK_DIST_CLASSPATH" ->
             (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
           "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))),
    +    javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir",
         javaOptions in Test += "-Dspark.test.home=" + sparkHome,
         javaOptions in Test += "-Dspark.testing=1",
         javaOptions in Test += "-Dspark.port.maxRetries=100",
    @@ -510,10 +514,11 @@ object TestSettings {
         javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
         javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
         javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
    +    javaOptions in Test += "-Dderby.system.durability=test",
         javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
           .map { case (k,v) => s"-D$k=$v" }.toSeq,
         javaOptions in Test += "-ea",
    -    javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
    +    javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
           .split(" ").toSeq,
         javaOptions += "-Xmx3g",
         // Show full stack trace and duration in test cases.
    @@ -523,6 +528,13 @@ object TestSettings {
         libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test",
         // Only allow one test at a time, even across projects, since they run in the same JVM
         parallelExecution in Test := false,
    +    // Make sure the test temp directory exists.
    +    resourceGenerators in Test <+= resourceManaged in Test map { outDir: File =>
    +      if (!new File(testTempDir).isDirectory()) {
    +        require(new File(testTempDir).mkdirs())
    +      }
    +      Seq[File]()
    +    },
         concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
         // Remove certain packages from Scaladoc
         scalacOptions in (Compile, doc) := Seq(
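
Two of the TestSettings changes work together: every forked test JVM gets -Djava.io.tmpdir pointed at the new testTempDir under target/tmp, and a resource generator pre-creates that directory. The practical effect is that scratch files created through the standard temp APIs land under the build directory rather than the system temp dir. A small shell-style sketch of what test code sees under these settings (illustrative, not from the patch):

```
import java.io.File

// Under the forked-test settings above, java.io.tmpdir resolves to <sparkHome>/target/tmp,
// so ordinary temp-file APIs write beneath the build directory and are cleaned with it.
val tmpRoot = new File(sys.props("java.io.tmpdir"))
val scratch = File.createTempFile("example", ".tmp", tmpRoot)
println(scratch.getAbsolutePath)
scratch.deleteOnExit()
```
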
    diff --git a/project/plugins.sbt b/project/plugins.sbt
    index 7096b0d3ee7de..51820460ca1a0 100644
    --- a/project/plugins.sbt
    +++ b/project/plugins.sbt
    @@ -25,10 +25,12 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
     
     addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1")
     
    -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1")
    +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3")
     
     addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2")
     
    +addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2")
    +
     libraryDependencies += "org.ow2.asm"  % "asm" % "5.0.3"
     
     libraryDependencies += "org.ow2.asm"  % "asm-commons" % "5.0.3"
    diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
    index 8379b8fc8a1e1..518b8e774dd5f 100644
    --- a/python/docs/pyspark.ml.rst
    +++ b/python/docs/pyspark.ml.rst
    @@ -1,8 +1,8 @@
     pyspark.ml package
    -=====================
    +==================
     
     ML Pipeline APIs
    ---------------
    +----------------
     
     .. automodule:: pyspark.ml
         :members:
    @@ -10,7 +10,7 @@ ML Pipeline APIs
         :inherited-members:
     
     pyspark.ml.param module
    --------------------------
    +-----------------------
     
     .. automodule:: pyspark.ml.param
         :members:
    @@ -34,7 +34,7 @@ pyspark.ml.classification module
         :inherited-members:
     
     pyspark.ml.recommendation module
    --------------------------
    +--------------------------------
     
     .. automodule:: pyspark.ml.recommendation
         :members:
    @@ -42,7 +42,7 @@ pyspark.ml.recommendation module
         :inherited-members:
     
     pyspark.ml.regression module
    --------------------------
    +----------------------------
     
     .. automodule:: pyspark.ml.regression
         :members:
    @@ -50,7 +50,7 @@ pyspark.ml.regression module
         :inherited-members:
     
     pyspark.ml.tuning module
    ---------------------------------
    +------------------------
     
     .. automodule:: pyspark.ml.tuning
         :members:
    @@ -58,7 +58,7 @@ pyspark.ml.tuning module
         :inherited-members:
     
     pyspark.ml.evaluation module
    ---------------------------------
    +----------------------------
     
     .. automodule:: pyspark.ml.evaluation
         :members:
    diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
    index 0d21a132048a5..6ef8cf53cc747 100644
    --- a/python/pyspark/accumulators.py
    +++ b/python/pyspark/accumulators.py
    @@ -261,3 +261,9 @@ def _start_update_server():
         thread.daemon = True
         thread.start()
         return server
    +
    +if __name__ == "__main__":
    +    import doctest
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
    diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
    index 3de4615428bb6..663c9abe0881e 100644
    --- a/python/pyspark/broadcast.py
    +++ b/python/pyspark/broadcast.py
    @@ -115,4 +115,6 @@ def __reduce__(self):
     
     if __name__ == "__main__":
         import doctest
    -    doctest.testmod()
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
    diff --git a/python/pyspark/context.py b/python/pyspark/context.py
    index 31992795a9e45..d7466729b8f36 100644
    --- a/python/pyspark/context.py
    +++ b/python/pyspark/context.py
    @@ -173,6 +173,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
                 self._jvm.PythonAccumulatorParam(host, port))
     
             self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
    +        self.pythonVer = "%d.%d" % sys.version_info[:2]
     
             # Broadcast's __reduce__ method stores Broadcast instances here.
             # This allows other code to determine which Broadcast instances have
    @@ -290,6 +291,26 @@ def version(self):
             """
             return self._jsc.version()
     
    +    @property
    +    @ignore_unicode_prefix
    +    def applicationId(self):
    +        """
    +        A unique identifier for the Spark application.
    +        Its format depends on the scheduler implementation.
+        (e.g.
+            in the case of a local Spark app, something like 'local-1433865536131';
+            in the case of YARN, something like 'application_1433865536131_34483'
+        )
    +        >>> sc.applicationId  # doctest: +ELLIPSIS
    +        u'local-...'
    +        """
    +        return self._jsc.sc().applicationId()
    +
    +    @property
    +    def startTime(self):
    +        """Return the epoch time when the Spark Context was started."""
    +        return self._jsc.startTime()
    +
         @property
         def defaultParallelism(self):
             """
    @@ -318,6 +339,38 @@ def stop(self):
             with SparkContext._lock:
                 SparkContext._active_spark_context = None
     
    +    def emptyRDD(self):
    +        """
    +        Create an RDD that has no partitions or elements.
    +        """
    +        return RDD(self._jsc.emptyRDD(), self, NoOpSerializer())
    +
    +    def range(self, start, end=None, step=1, numSlices=None):
    +        """
    +        Create a new RDD of int containing elements from `start` to `end`
    +        (exclusive), increased by `step` every element. Can be called the same
    +        way as python's built-in range() function. If called with a single argument,
    +        the argument is interpreted as `end`, and `start` is set to 0.
    +
    +        :param start: the start value
    +        :param end: the end value (exclusive)
    +        :param step: the incremental step (default: 1)
    +        :param numSlices: the number of partitions of the new RDD
    +        :return: An RDD of int
    +
    +        >>> sc.range(5).collect()
    +        [0, 1, 2, 3, 4]
    +        >>> sc.range(2, 4).collect()
    +        [2, 3]
    +        >>> sc.range(1, 7, 2).collect()
    +        [1, 3, 5]
    +        """
    +        if end is None:
    +            end = start
    +            start = 0
    +
    +        return self.parallelize(xrange(start, end, step), numSlices)
    +
         def parallelize(self, c, numSlices=None):
             """
             Distribute a local Python collection to form an RDD. Using xrange
    diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py
    index 4ef2afe03544f..b27e91a4cc251 100644
    --- a/python/pyspark/heapq3.py
    +++ b/python/pyspark/heapq3.py
    @@ -883,6 +883,7 @@ def nlargest(n, iterable, key=None):
     
     
     if __name__ == "__main__":
    -
         import doctest
    -    print(doctest.testmod())
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
    diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
    index 3cee4ea6e3a35..90cd342a6cf7f 100644
    --- a/python/pyspark/java_gateway.py
    +++ b/python/pyspark/java_gateway.py
    @@ -51,6 +51,8 @@ def launch_gateway():
             on_windows = platform.system() == "Windows"
             script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
             submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
    +        if os.environ.get("SPARK_TESTING"):
    +            submit_args = "--conf spark.ui.enabled=false " + submit_args
             command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args)
     
             # Start a socket that will be used by PythonGatewayServer to communicate its port to us
    diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
    index da793d9db7f91..327a11b14b5aa 100644
    --- a/python/pyspark/ml/__init__.py
    +++ b/python/pyspark/ml/__init__.py
    @@ -15,6 +15,6 @@
     # limitations under the License.
     #
     
    -from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel, Evaluator
    +from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel
     
    -__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel", "Evaluator"]
    +__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
    diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
    index 8a009c4ac721f..89117e492846b 100644
    --- a/python/pyspark/ml/classification.py
    +++ b/python/pyspark/ml/classification.py
    @@ -17,17 +17,20 @@
     
     from pyspark.ml.util import keyword_only
     from pyspark.ml.wrapper import JavaEstimator, JavaModel
    -from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
    -    HasRegParam
    +from pyspark.ml.param.shared import *
    +from pyspark.ml.regression import (
    +    RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
     from pyspark.mllib.common import inherit_doc
     
     
    -__all__ = ['LogisticRegression', 'LogisticRegressionModel']
    +__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier',
    +           'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel',
    +           'RandomForestClassifier', 'RandomForestClassificationModel']
     
     
     @inherit_doc
     class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
    -                         HasRegParam):
    +                         HasRegParam, HasTol, HasProbabilityCol):
         """
         Logistic regression.
     
    @@ -41,6 +44,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
         >>> model.transform(test0).head().prediction
         0.0
    +    >>> model.weights
    +    DenseVector([5.5...])
    +    >>> model.intercept
    +    -2.68...
         >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
         >>> model.transform(test1).head().prediction
         1.0
    @@ -49,26 +56,52 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
             ...
         TypeError: Method setParams forces keyword arguments.
         """
    -    _java_class = "org.apache.spark.ml.classification.LogisticRegression"
    +
    +    # a placeholder to make it appear in the generated doc
    +    elasticNetParam = \
    +        Param(Params._dummy(), "elasticNetParam",
    +              "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
    +              "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
    +    fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
    +    threshold = Param(Params._dummy(), "threshold",
    +                      "threshold in binary classification prediction, in range [0, 1].")
     
         @keyword_only
         def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    -                 maxIter=100, regParam=0.1):
    +                 maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
    +                 threshold=0.5, probabilityCol="probability"):
             """
             __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    -                 maxIter=100, regParam=0.1)
    +                 maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
    +                 threshold=0.5, probabilityCol="probability")
             """
             super(LogisticRegression, self).__init__()
    -        self._setDefault(maxIter=100, regParam=0.1)
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.classification.LogisticRegression", self.uid)
    +        #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
    +        #  is an L2 penalty. For alpha = 1, it is an L1 penalty.
    +        self.elasticNetParam = \
    +            Param(self, "elasticNetParam",
    +                  "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " +
    +                  "is an L2 penalty. For alpha = 1, it is an L1 penalty.")
    +        #: param for whether to fit an intercept term.
    +        self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
    +        #: param for threshold in binary classification prediction, in range [0, 1].
    +        self.threshold = Param(self, "threshold",
    +                               "threshold in binary classification prediction, in range [0, 1].")
    +        self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
    +                         fitIntercept=True, threshold=0.5)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
         @keyword_only
         def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    -                  maxIter=100, regParam=0.1):
    +                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
    +                  threshold=0.5, probabilityCol="probability"):
             """
             setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    -                  maxIter=100, regParam=0.1)
    +                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
+                  threshold=0.5, probabilityCol="probability")
             Sets params for logistic regression.
             """
             kwargs = self.setParams._input_kwargs
    @@ -77,12 +110,471 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
         def _create_model(self, java_model):
             return LogisticRegressionModel(java_model)
     
    +    def setElasticNetParam(self, value):
    +        """
    +        Sets the value of :py:attr:`elasticNetParam`.
    +        """
    +        self._paramMap[self.elasticNetParam] = value
    +        return self
    +
    +    def getElasticNetParam(self):
    +        """
    +        Gets the value of elasticNetParam or its default value.
    +        """
    +        return self.getOrDefault(self.elasticNetParam)
    +
    +    def setFitIntercept(self, value):
    +        """
    +        Sets the value of :py:attr:`fitIntercept`.
    +        """
    +        self._paramMap[self.fitIntercept] = value
    +        return self
    +
    +    def getFitIntercept(self):
    +        """
    +        Gets the value of fitIntercept or its default value.
    +        """
    +        return self.getOrDefault(self.fitIntercept)
    +
    +    def setThreshold(self, value):
    +        """
    +        Sets the value of :py:attr:`threshold`.
    +        """
    +        self._paramMap[self.threshold] = value
    +        return self
    +
    +    def getThreshold(self):
    +        """
    +        Gets the value of threshold or its default value.
    +        """
    +        return self.getOrDefault(self.threshold)
    +
     
     class LogisticRegressionModel(JavaModel):
         """
         Model fitted by LogisticRegression.
         """
     
    +    @property
    +    def weights(self):
    +        """
    +        Model weights.
    +        """
    +        return self._call_java("weights")
    +
    +    @property
    +    def intercept(self):
    +        """
    +        Model intercept.
    +        """
    +        return self._call_java("intercept")
    +
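
A rough usage sketch of the elastic-net, tolerance, intercept and threshold params added above, together with the newly exposed model coefficients (assuming a live `sqlContext`, as the doctests in this file do; the tiny dataset and values are illustrative only):

```
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import LogisticRegression

# Toy two-row dataset, mirroring the doctest above.
df = sqlContext.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

# elasticNetParam=0.0 is a pure L2 penalty, 1.0 is pure L1;
# values in between blend the two penalties.
lr = LogisticRegression(maxIter=10, regParam=0.01, elasticNetParam=0.5,
                        tol=1e-6, fitIntercept=True, threshold=0.5)
model = lr.fit(df)
print(model.weights)    # DenseVector of fitted coefficients
print(model.intercept)  # fitted intercept term
```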
    +
    +class TreeClassifierParams(object):
    +    """
    +    Private class to track supported impurity measures.
    +    """
    +    supportedImpurities = ["entropy", "gini"]
    +
    +
    +class GBTParams(object):
    +    """
    +    Private class to track supported GBT params.
    +    """
    +    supportedLossTypes = ["logistic"]
    +
    +
    +@inherit_doc
    +class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
    +                             DecisionTreeParams, HasCheckpointInterval):
    +    """
+    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    +    learning algorithm for classification.
    +    It supports both binary and multiclass labels, as well as both continuous and categorical
    +    features.
    +
    +    >>> from pyspark.mllib.linalg import Vectors
    +    >>> from pyspark.ml.feature import StringIndexer
    +    >>> df = sqlContext.createDataFrame([
    +    ...     (1.0, Vectors.dense(1.0)),
    +    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    +    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    +    >>> si_model = stringIndexer.fit(df)
    +    >>> td = si_model.transform(df)
    +    >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
    +    >>> model = dt.fit(td)
    +    >>> model.numNodes
    +    3
    +    >>> model.depth
    +    1
    +    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    +    >>> model.transform(test0).head().prediction
    +    0.0
    +    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    +    >>> model.transform(test1).head().prediction
    +    1.0
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    impurity = Param(Params._dummy(), "impurity",
    +                     "Criterion used for information gain calculation (case-insensitive). " +
    +                     "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
    +
    +    @keyword_only
    +    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"):
    +        """
    +        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
    +        """
    +        super(DecisionTreeClassifier, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
    +        #: param for Criterion used for information gain calculation (case-insensitive).
    +        self.impurity = \
    +            Param(self, "impurity",
    +                  "Criterion used for information gain calculation (case-insensitive). " +
    +                  "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
    +        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
    +                         impurity="gini")
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
    +                  impurity="gini"):
    +        """
    +        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
    +        Sets params for the DecisionTreeClassifier.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +    def _create_model(self, java_model):
    +        return DecisionTreeClassificationModel(java_model)
    +
    +    def setImpurity(self, value):
    +        """
    +        Sets the value of :py:attr:`impurity`.
    +        """
    +        self._paramMap[self.impurity] = value
    +        return self
    +
    +    def getImpurity(self):
    +        """
    +        Gets the value of impurity or its default value.
    +        """
    +        return self.getOrDefault(self.impurity)
    +
    +
    +@inherit_doc
    +class DecisionTreeClassificationModel(DecisionTreeModel):
    +    """
    +    Model fitted by DecisionTreeClassifier.
    +    """
    +
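
A minimal end-to-end sketch for the new decision tree classifier (same `sqlContext` assumption as the doctests; the data is illustrative only):

```
from pyspark.mllib.linalg import Vectors
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier

df = sqlContext.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

# Index the label column first, as the doctest does, then fit a small tree
# using entropy instead of the default gini impurity.
td = StringIndexer(inputCol="label", outputCol="indexed").fit(df).transform(df)
dt = DecisionTreeClassifier(maxDepth=2, impurity="entropy", labelCol="indexed")
model = dt.fit(td)
print(model.numNodes, model.depth)   # size of the fitted tree
model.transform(td).select("features", "prediction").show()
```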
    +
    +@inherit_doc
    +class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
    +                             DecisionTreeParams, HasCheckpointInterval):
    +    """
+    `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
    +    learning algorithm for classification.
    +    It supports both binary and multiclass labels, as well as both continuous and categorical
    +    features.
    +
    +    >>> from numpy import allclose
    +    >>> from pyspark.mllib.linalg import Vectors
    +    >>> from pyspark.ml.feature import StringIndexer
    +    >>> df = sqlContext.createDataFrame([
    +    ...     (1.0, Vectors.dense(1.0)),
    +    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    +    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    +    >>> si_model = stringIndexer.fit(df)
    +    >>> td = si_model.transform(df)
    +    >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
    +    >>> model = rf.fit(td)
    +    >>> allclose(model.treeWeights, [1.0, 1.0])
    +    True
    +    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    +    >>> model.transform(test0).head().prediction
    +    0.0
    +    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    +    >>> model.transform(test1).head().prediction
    +    1.0
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    impurity = Param(Params._dummy(), "impurity",
    +                     "Criterion used for information gain calculation (case-insensitive). " +
    +                     "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
    +    subsamplingRate = Param(Params._dummy(), "subsamplingRate",
    +                            "Fraction of the training data used for learning each decision tree, " +
    +                            "in range (0, 1].")
    +    numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)")
    +    featureSubsetStrategy = \
    +        Param(Params._dummy(), "featureSubsetStrategy",
    +              "The number of features to consider for splits at each tree node. Supported " +
    +              "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
    +
    +    @keyword_only
    +    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
    +                 numTrees=20, featureSubsetStrategy="auto", seed=None):
    +        """
    +        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
    +                 numTrees=20, featureSubsetStrategy="auto", seed=None)
    +        """
    +        super(RandomForestClassifier, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
    +        #: param for Criterion used for information gain calculation (case-insensitive).
    +        self.impurity = \
    +            Param(self, "impurity",
    +                  "Criterion used for information gain calculation (case-insensitive). " +
    +                  "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities))
    +        #: param for Fraction of the training data used for learning each decision tree,
    +        #  in range (0, 1]
    +        self.subsamplingRate = Param(self, "subsamplingRate",
    +                                     "Fraction of the training data used for learning each " +
    +                                     "decision tree, in range (0, 1].")
    +        #: param for Number of trees to train (>= 1)
    +        self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)")
    +        #: param for The number of features to consider for splits at each tree node
    +        self.featureSubsetStrategy = \
    +            Param(self, "featureSubsetStrategy",
    +                  "The number of features to consider for splits at each tree node. Supported " +
    +                  "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
    +        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
    +                         impurity="gini", numTrees=20, featureSubsetStrategy="auto")
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
    +                  impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
    +        """
    +        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
    +                  impurity="gini", numTrees=20, featureSubsetStrategy="auto")
+        Sets params for the RandomForestClassifier.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +    def _create_model(self, java_model):
    +        return RandomForestClassificationModel(java_model)
    +
    +    def setImpurity(self, value):
    +        """
    +        Sets the value of :py:attr:`impurity`.
    +        """
    +        self._paramMap[self.impurity] = value
    +        return self
    +
    +    def getImpurity(self):
    +        """
    +        Gets the value of impurity or its default value.
    +        """
    +        return self.getOrDefault(self.impurity)
    +
    +    def setSubsamplingRate(self, value):
    +        """
    +        Sets the value of :py:attr:`subsamplingRate`.
    +        """
    +        self._paramMap[self.subsamplingRate] = value
    +        return self
    +
    +    def getSubsamplingRate(self):
    +        """
    +        Gets the value of subsamplingRate or its default value.
    +        """
    +        return self.getOrDefault(self.subsamplingRate)
    +
    +    def setNumTrees(self, value):
    +        """
    +        Sets the value of :py:attr:`numTrees`.
    +        """
    +        self._paramMap[self.numTrees] = value
    +        return self
    +
    +    def getNumTrees(self):
    +        """
    +        Gets the value of numTrees or its default value.
    +        """
    +        return self.getOrDefault(self.numTrees)
    +
    +    def setFeatureSubsetStrategy(self, value):
    +        """
    +        Sets the value of :py:attr:`featureSubsetStrategy`.
    +        """
    +        self._paramMap[self.featureSubsetStrategy] = value
    +        return self
    +
    +    def getFeatureSubsetStrategy(self):
    +        """
    +        Gets the value of featureSubsetStrategy or its default value.
    +        """
    +        return self.getOrDefault(self.featureSubsetStrategy)
    +
    +
    +class RandomForestClassificationModel(TreeEnsembleModels):
    +    """
    +    Model fitted by RandomForestClassifier.
    +    """
    +
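
A similar sketch for the random forest, showing the ensemble-specific params (illustrative data; note that subsamplingRate is not part of the keyword constructor above, so it is only reachable through its setter):

```
from pyspark.mllib.linalg import Vectors
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import RandomForestClassifier

df = sqlContext.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
td = StringIndexer(inputCol="label", outputCol="indexed").fit(df).transform(df)

# numTrees sets the ensemble size, featureSubsetStrategy how many features
# each split considers, and seed makes the run reproducible; subsamplingRate
# goes through its setter (1.0 is the default).
rf = RandomForestClassifier(numTrees=2, maxDepth=2, featureSubsetStrategy="auto",
                            seed=42, labelCol="indexed").setSubsamplingRate(1.0)
model = rf.fit(td)
print(model.treeWeights)                 # one weight per tree
print(model.transform(td).head().prediction)
```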
    +
    +@inherit_doc
    +class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
    +                    DecisionTreeParams, HasCheckpointInterval):
    +    """
+    `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
    +    learning algorithm for classification.
    +    It supports binary labels, as well as both continuous and categorical features.
    +    Note: Multiclass labels are not currently supported.
    +
    +    >>> from numpy import allclose
    +    >>> from pyspark.mllib.linalg import Vectors
    +    >>> from pyspark.ml.feature import StringIndexer
    +    >>> df = sqlContext.createDataFrame([
    +    ...     (1.0, Vectors.dense(1.0)),
    +    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    +    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    +    >>> si_model = stringIndexer.fit(df)
    +    >>> td = si_model.transform(df)
    +    >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
    +    >>> model = gbt.fit(td)
    +    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    +    True
    +    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    +    >>> model.transform(test0).head().prediction
    +    0.0
    +    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    +    >>> model.transform(test1).head().prediction
    +    1.0
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    lossType = Param(Params._dummy(), "lossType",
    +                     "Loss function which GBT tries to minimize (case-insensitive). " +
    +                     "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
    +    subsamplingRate = Param(Params._dummy(), "subsamplingRate",
    +                            "Fraction of the training data used for learning each decision tree, " +
    +                            "in range (0, 1].")
    +    stepSize = Param(Params._dummy(), "stepSize",
    +                     "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " +
    +                     "contribution of each estimator")
    +
    +    @keyword_only
    +    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
    +                 maxIter=20, stepSize=0.1):
    +        """
    +        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
    +                 lossType="logistic", maxIter=20, stepSize=0.1)
    +        """
    +        super(GBTClassifier, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.classification.GBTClassifier", self.uid)
    +        #: param for Loss function which GBT tries to minimize (case-insensitive).
    +        self.lossType = Param(self, "lossType",
    +                              "Loss function which GBT tries to minimize (case-insensitive). " +
    +                              "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
    +        #: Fraction of the training data used for learning each decision tree, in range (0, 1].
    +        self.subsamplingRate = Param(self, "subsamplingRate",
    +                                     "Fraction of the training data used for learning each " +
    +                                     "decision tree, in range (0, 1].")
    +        #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of
    +        #  each estimator
    +        self.stepSize = Param(self, "stepSize",
    +                              "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
    +                              "the contribution of each estimator")
    +        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
    +                         lossType="logistic", maxIter=20, stepSize=0.1)
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
    +                  lossType="logistic", maxIter=20, stepSize=0.1):
    +        """
    +        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
    +                  lossType="logistic", maxIter=20, stepSize=0.1)
    +        Sets params for Gradient Boosted Tree Classification.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +    def _create_model(self, java_model):
    +        return GBTClassificationModel(java_model)
    +
    +    def setLossType(self, value):
    +        """
    +        Sets the value of :py:attr:`lossType`.
    +        """
    +        self._paramMap[self.lossType] = value
    +        return self
    +
    +    def getLossType(self):
    +        """
    +        Gets the value of lossType or its default value.
    +        """
    +        return self.getOrDefault(self.lossType)
    +
    +    def setSubsamplingRate(self, value):
    +        """
    +        Sets the value of :py:attr:`subsamplingRate`.
    +        """
    +        self._paramMap[self.subsamplingRate] = value
    +        return self
    +
    +    def getSubsamplingRate(self):
    +        """
    +        Gets the value of subsamplingRate or its default value.
    +        """
    +        return self.getOrDefault(self.subsamplingRate)
    +
    +    def setStepSize(self, value):
    +        """
    +        Sets the value of :py:attr:`stepSize`.
    +        """
    +        self._paramMap[self.stepSize] = value
    +        return self
    +
    +    def getStepSize(self):
    +        """
    +        Gets the value of stepSize or its default value.
    +        """
    +        return self.getOrDefault(self.stepSize)
    +
    +
    +class GBTClassificationModel(TreeEnsembleModels):
    +    """
    +    Model fitted by GBTClassifier.
    +    """
    +
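
And the gradient-boosted trees counterpart: only the "logistic" loss is supported, and stepSize shrinks the contribution of each additional tree (illustrative data again, `sqlContext` assumed):

```
from pyspark.mllib.linalg import Vectors
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import GBTClassifier

df = sqlContext.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
td = StringIndexer(inputCol="label", outputCol="indexed").fit(df).transform(df)

# maxIter is the number of boosting rounds, i.e. the number of trees.
gbt = GBTClassifier(maxIter=5, maxDepth=2, stepSize=0.1,
                    lossType="logistic", labelCol="indexed")
model = gbt.fit(td)
print(model.treeWeights)             # [1.0, 0.1, 0.1, 0.1, 0.1] per the doctest
print(model.transform(td).head().prediction)
```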
     
     if __name__ == "__main__":
         import doctest
    diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
    index 02020ebff94c2..595593a7f2cde 100644
    --- a/python/pyspark/ml/evaluation.py
    +++ b/python/pyspark/ml/evaluation.py
    @@ -15,13 +15,72 @@
     # limitations under the License.
     #
     
    -from pyspark.ml.wrapper import JavaEvaluator
    +from abc import abstractmethod, ABCMeta
    +
    +from pyspark.ml.wrapper import JavaWrapper
     from pyspark.ml.param import Param, Params
    -from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
    +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
     from pyspark.ml.util import keyword_only
     from pyspark.mllib.common import inherit_doc
     
    -__all__ = ['BinaryClassificationEvaluator']
    +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator']
    +
    +
    +@inherit_doc
    +class Evaluator(Params):
    +    """
    +    Base class for evaluators that compute metrics from predictions.
    +    """
    +
    +    __metaclass__ = ABCMeta
    +
    +    @abstractmethod
    +    def _evaluate(self, dataset):
    +        """
    +        Evaluates the output.
    +
    +        :param dataset: a dataset that contains labels/observations and
    +               predictions
    +        :return: metric
    +        """
    +        raise NotImplementedError()
    +
    +    def evaluate(self, dataset, params={}):
    +        """
    +        Evaluates the output with optional parameters.
    +
    +        :param dataset: a dataset that contains labels/observations and
    +                        predictions
    +        :param params: an optional param map that overrides embedded
    +                       params
    +        :return: metric
    +        """
    +        if isinstance(params, dict):
    +            if params:
    +                return self.copy(params)._evaluate(dataset)
    +            else:
    +                return self._evaluate(dataset)
    +        else:
    +            raise ValueError("Params must be a param map but got %s." % type(params))
    +
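
Because `evaluate` only dispatches to the abstract `_evaluate` (optionally through a copy carrying the overriding param map), a pure-Python metric needs nothing more than an `_evaluate` implementation. A minimal sketch; the class name and the hard-coded column names are made up for illustration:

```
from pyspark.ml.evaluation import Evaluator

class AccuracyEvaluator(Evaluator):
    """Hypothetical evaluator: fraction of rows where prediction == label."""

    def _evaluate(self, dataset):
        # Assumes the dataset has "label" and "prediction" columns.
        total = dataset.count()
        correct = dataset.filter(dataset.label == dataset.prediction).count()
        return float(correct) / total

# evaluator = AccuracyEvaluator()
# evaluator.evaluate(predictions)   # `predictions` is any DataFrame with the
#                                   # two columns above, e.g. model.transform(df)
```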
    +
    +@inherit_doc
    +class JavaEvaluator(Evaluator, JavaWrapper):
    +    """
+    Base class for :py:class:`Evaluator` subclasses that wrap
+    Java/Scala implementations.
    +    """
    +
    +    __metaclass__ = ABCMeta
    +
    +    def _evaluate(self, dataset):
    +        """
    +        Evaluates the output.
    +        :param dataset: a dataset that contains labels/observations and predictions.
    +        :return: evaluation metric
    +        """
    +        self._transfer_params_to_java()
    +        return self._java_obj.evaluate(dataset._jdf)
     
     
     @inherit_doc
    @@ -42,8 +101,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
         0.83...
         """
     
    -    _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator"
    -
         # a placeholder to make it appear in the generated doc
         metricName = Param(Params._dummy(), "metricName",
                            "metric name in evaluation (areaUnderROC|areaUnderPR)")
    @@ -56,6 +113,8 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
                      metricName="areaUnderROC")
             """
             super(BinaryClassificationEvaluator, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
             #: param for metric name in evaluation (areaUnderROC|areaUnderPR)
             self.metricName = Param(self, "metricName",
                                     "metric name in evaluation (areaUnderROC|areaUnderPR)")
    @@ -68,7 +127,7 @@ def setMetricName(self, value):
             """
             Sets the value of :py:attr:`metricName`.
             """
    -        self.paramMap[self.metricName] = value
    +        self._paramMap[self.metricName] = value
             return self
     
         def getMetricName(self):
    @@ -89,6 +148,72 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
             return self._set(**kwargs)
     
     
    +@inherit_doc
    +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
    +    """
    +    Evaluator for Regression, which expects two input
    +    columns: prediction and label.
    +
    +    >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
    +    ...   (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
    +    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
    +    >>> evaluator = RegressionEvaluator(predictionCol="raw")
    +    >>> evaluator.evaluate(dataset)
    +    -2.842...
    +    >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
    +    0.993...
    +    >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
    +    -2.649...
    +    """
+    # Because we will maximize evaluation value (ref: `CrossValidator`),
+    # when we evaluate a metric that should be minimized (e.g. "rmse", "mse", "mae"),
+    # we return the negative of that metric so that larger is always better.
    +    metricName = Param(Params._dummy(), "metricName",
    +                       "metric name in evaluation (mse|rmse|r2|mae)")
    +
    +    @keyword_only
    +    def __init__(self, predictionCol="prediction", labelCol="label",
    +                 metricName="rmse"):
    +        """
    +        __init__(self, predictionCol="prediction", labelCol="label", \
    +                 metricName="rmse")
    +        """
    +        super(RegressionEvaluator, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
    +        #: param for metric name in evaluation (mse|rmse|r2|mae)
    +        self.metricName = Param(self, "metricName",
    +                                "metric name in evaluation (mse|rmse|r2|mae)")
    +        self._setDefault(predictionCol="prediction", labelCol="label",
    +                         metricName="rmse")
    +        kwargs = self.__init__._input_kwargs
    +        self._set(**kwargs)
    +
    +    def setMetricName(self, value):
    +        """
    +        Sets the value of :py:attr:`metricName`.
    +        """
    +        self._paramMap[self.metricName] = value
    +        return self
    +
    +    def getMetricName(self):
    +        """
    +        Gets the value of metricName or its default value.
    +        """
    +        return self.getOrDefault(self.metricName)
    +
    +    @keyword_only
    +    def setParams(self, predictionCol="prediction", labelCol="label",
    +                  metricName="rmse"):
    +        """
    +        setParams(self, predictionCol="prediction", labelCol="label", \
    +                  metricName="rmse")
    +        Sets params for regression evaluator.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
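
A short usage sketch of the evaluator above, including the per-call param override that `evaluate` accepts (illustrative data; exact metric values are omitted):

```
from pyspark.ml.evaluation import RegressionEvaluator

preds = sqlContext.createDataFrame(
    [(-28.98, -27.0), (20.21, 21.5), (30.70, 33.0)], ["raw", "label"])

evaluator = RegressionEvaluator(predictionCol="raw", metricName="r2")
print(evaluator.evaluate(preds))                   # r2: larger is better
# Override the metric for a single call; rmse/mse/mae come back negated
# so that CrossValidator can always maximize the evaluation value.
print(evaluator.evaluate(preds, {evaluator.metricName: "mae"}))
```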
     if __name__ == "__main__":
         import doctest
         from pyspark.context import SparkContext
    diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
    index f35bc1463d51b..9bca7cc000aa5 100644
    --- a/python/pyspark/ml/feature.py
    +++ b/python/pyspark/ml/feature.py
    @@ -21,7 +21,7 @@
     from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
     from pyspark.mllib.common import inherit_doc
     
    -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder',
+__all__ = ['Binarizer', 'Bucketizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer',
+           'OneHotEncoder',
                'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
                'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
                'Word2Vec', 'Word2VecModel']
    @@ -43,7 +43,6 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
         1.0
         """
     
    -    _java_class = "org.apache.spark.ml.feature.Binarizer"
         # a placeholder to make it appear in the generated doc
         threshold = Param(Params._dummy(), "threshold",
                           "threshold in binary classification prediction, in range [0, 1]")
    @@ -54,6 +53,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None):
             __init__(self, threshold=0.0, inputCol=None, outputCol=None)
             """
             super(Binarizer, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid)
             self.threshold = Param(self, "threshold",
                                    "threshold in binary classification prediction, in range [0, 1]")
             self._setDefault(threshold=0.0)
    @@ -73,7 +73,7 @@ def setThreshold(self, value):
             """
             Sets the value of :py:attr:`threshold`.
             """
    -        self.paramMap[self.threshold] = value
    +        self._paramMap[self.threshold] = value
             return self
     
         def getThreshold(self):
    @@ -83,6 +83,83 @@ def getThreshold(self):
             return self.getOrDefault(self.threshold)
     
     
    +@inherit_doc
    +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
    +    """
    +    Maps a column of continuous features to a column of feature buckets.
    +
    +    >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
    +    >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
    +    ...     inputCol="values", outputCol="buckets")
    +    >>> bucketed = bucketizer.transform(df).collect()
    +    >>> bucketed[0].buckets
    +    0.0
    +    >>> bucketed[1].buckets
    +    0.0
    +    >>> bucketed[2].buckets
    +    1.0
    +    >>> bucketed[3].buckets
    +    2.0
    +    >>> bucketizer.setParams(outputCol="b").transform(df).head().b
    +    0.0
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    splits = \
    +        Param(Params._dummy(), "splits",
    +              "Split points for mapping continuous features into buckets. With n+1 splits, " +
    +              "there are n buckets. A bucket defined by splits x,y holds values in the " +
    +              "range [x,y) except the last bucket, which also includes y. The splits " +
    +              "should be strictly increasing. Values at -inf, inf must be explicitly " +
    +              "provided to cover all Double values; otherwise, values outside the splits " +
    +              "specified will be treated as errors.")
    +
    +    @keyword_only
    +    def __init__(self, splits=None, inputCol=None, outputCol=None):
    +        """
    +        __init__(self, splits=None, inputCol=None, outputCol=None)
    +        """
    +        super(Bucketizer, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
    +        #: param for Splitting points for mapping continuous features into buckets. With n+1 splits,
    +        #  there are n buckets. A bucket defined by splits x,y holds values in the range [x,y)
    +        #  except the last bucket, which also includes y. The splits should be strictly increasing.
    +        #  Values at -inf, inf must be explicitly provided to cover all Double values; otherwise,
    +        #  values outside the splits specified will be treated as errors.
    +        self.splits = \
    +            Param(self, "splits",
    +                  "Split points for mapping continuous features into buckets. With n+1 splits, " +
    +                  "there are n buckets. A bucket defined by splits x,y holds values in the " +
    +                  "range [x,y) except the last bucket, which also includes y. The splits " +
    +                  "should be strictly increasing. Values at -inf, inf must be explicitly " +
    +                  "provided to cover all Double values; otherwise, values outside the splits " +
    +                  "specified will be treated as errors.")
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, splits=None, inputCol=None, outputCol=None):
    +        """
    +        setParams(self, splits=None, inputCol=None, outputCol=None)
    +        Sets params for this Bucketizer.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +    def setSplits(self, value):
    +        """
    +        Sets the value of :py:attr:`splits`.
    +        """
    +        self._paramMap[self.splits] = value
    +        return self
    +
    +    def getSplits(self):
    +        """
+        Gets the value of splits or its default value.
    +        """
    +        return self.getOrDefault(self.splits)
    +
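
The splits semantics are easiest to see on a handful of values (illustrative data, `sqlContext` assumed as in the doctests):

```
from pyspark.ml.feature import Bucketizer

df = sqlContext.createDataFrame([(-0.5,), (0.3,), (0.99,), (2.0,)], ["values"])
bucketizer = Bucketizer(splits=[-float("inf"), 0.0, 0.5, 1.4, float("inf")],
                        inputCol="values", outputCol="buckets")
# -0.5 -> 0.0, 0.3 -> 1.0, 0.99 -> 2.0, 2.0 -> 3.0: each bucket is [x, y)
# except the last one, which also includes its upper bound.
bucketizer.transform(df).show()
```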
    +
     @inherit_doc
     class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
         """
    @@ -100,14 +177,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
         SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
         """
     
    -    _java_class = "org.apache.spark.ml.feature.HashingTF"
    -
         @keyword_only
         def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
             """
             __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
             """
             super(HashingTF, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
             self._setDefault(numFeatures=1 << 18)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
    @@ -140,8 +216,6 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
         DenseVector([0.2877, 0.0])
         """
     
    -    _java_class = "org.apache.spark.ml.feature.IDF"
    -
         # a placeholder to make it appear in the generated doc
         minDocFreq = Param(Params._dummy(), "minDocFreq",
                            "minimum of documents in which a term should appear for filtering")
    @@ -152,6 +226,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None):
             __init__(self, minDocFreq=0, inputCol=None, outputCol=None)
             """
             super(IDF, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid)
             self.minDocFreq = Param(self, "minDocFreq",
                                     "minimum of documents in which a term should appear for filtering")
             self._setDefault(minDocFreq=0)
    @@ -171,7 +246,7 @@ def setMinDocFreq(self, value):
             """
             Sets the value of :py:attr:`minDocFreq`.
             """
    -        self.paramMap[self.minDocFreq] = value
    +        self._paramMap[self.minDocFreq] = value
             return self
     
         def getMinDocFreq(self):
    @@ -180,6 +255,9 @@ def getMinDocFreq(self):
             """
             return self.getOrDefault(self.minDocFreq)
     
    +    def _create_model(self, java_model):
    +        return IDFModel(java_model)
    +
     
     class IDFModel(JavaModel):
         """
    @@ -187,6 +265,75 @@ class IDFModel(JavaModel):
         """
     
     
    +@inherit_doc
    +@ignore_unicode_prefix
    +class NGram(JavaTransformer, HasInputCol, HasOutputCol):
    +    """
    +    A feature transformer that converts the input array of strings into an array of n-grams. Null
    +    values in the input array are ignored.
    +    It returns an array of n-grams where each n-gram is represented by a space-separated string of
    +    words.
    +    When the input is empty, an empty array is returned.
    +    When the input array length is less than n (number of elements per n-gram), no n-grams are
    +    returned.
    +
    +    >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
    +    >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams")
    +    >>> ngram.transform(df).head()
    +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e'])
    +    >>> # Change n-gram length
    +    >>> ngram.setParams(n=4).transform(df).head()
    +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
    +    >>> # Temporarily modify output column.
    +    >>> ngram.transform(df, {ngram.outputCol: "output"}).head()
    +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e'])
    +    >>> ngram.transform(df).head()
    +    Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e'])
    +    >>> # Must use keyword arguments to specify params.
    +    >>> ngram.setParams("text")
    +    Traceback (most recent call last):
    +        ...
    +    TypeError: Method setParams forces keyword arguments.
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)")
    +
    +    @keyword_only
    +    def __init__(self, n=2, inputCol=None, outputCol=None):
    +        """
    +        __init__(self, n=2, inputCol=None, outputCol=None)
    +        """
    +        super(NGram, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
    +        self.n = Param(self, "n", "number of elements per n-gram (>=1)")
    +        self._setDefault(n=2)
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, n=2, inputCol=None, outputCol=None):
    +        """
    +        setParams(self, n=2, inputCol=None, outputCol=None)
    +        Sets params for this NGram.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +    def setN(self, value):
    +        """
    +        Sets the value of :py:attr:`n`.
    +        """
    +        self._paramMap[self.n] = value
    +        return self
    +
    +    def getN(self):
    +        """
    +        Gets the value of n or its default value.
    +        """
    +        return self.getOrDefault(self.n)
    +
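
For instance (illustrative tokens, `sqlContext` assumed):

```
from pyspark.sql import Row
from pyspark.ml.feature import NGram

df = sqlContext.createDataFrame([Row(tokens=["to", "be", "or", "not", "to", "be"])])
ngram = NGram(n=3, inputCol="tokens", outputCol="trigrams")
# Each n-gram is one space-separated string, e.g. "to be or"; an input with
# fewer than n tokens yields an empty output array.
print(ngram.transform(df).head().trigrams)
```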
    +
     @inherit_doc
     class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
    @@ -208,14 +355,13 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
         # a placeholder to make it appear in the generated doc
         p = Param(Params._dummy(), "p", "the p norm value.")
     
    -    _java_class = "org.apache.spark.ml.feature.Normalizer"
    -
         @keyword_only
         def __init__(self, p=2.0, inputCol=None, outputCol=None):
             """
             __init__(self, p=2.0, inputCol=None, outputCol=None)
             """
             super(Normalizer, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid)
             self.p = Param(self, "p", "the p norm value.")
             self._setDefault(p=2.0)
             kwargs = self.__init__._input_kwargs
    @@ -234,7 +380,7 @@ def setP(self, value):
             """
             Sets the value of :py:attr:`p`.
             """
    -        self.paramMap[self.p] = value
    +        self._paramMap[self.p] = value
             return self
     
         def getP(self):
    @@ -247,66 +393,73 @@ def getP(self):
     @inherit_doc
     class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
         """
    -    A one-hot encoder that maps a column of label indices to a column of binary vectors, with
    -    at most a single one-value. By default, the binary vector has an element for each category, so
    -    with 5 categories, an input value of 2.0 would map to an output vector of
    -    (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so
    -    the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
    -    of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
    -    linearly dependent because they sum up to one.
    -
    -    TODO: This method requires the use of StringIndexer first. Decouple them.
    +    A one-hot encoder that maps a column of category indices to a
    +    column of binary vectors, with at most a single one-value per row
    +    that indicates the input category index.
    +    For example with 5 categories, an input value of 2.0 would map to
    +    an output vector of `[0.0, 0.0, 1.0, 0.0]`.
    +    The last category is not included by default (configurable via
    +    :py:attr:`dropLast`) because it makes the vector entries sum up to
    +    one, and hence linearly dependent.
    +    So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
    +    Note that this is different from scikit-learn's OneHotEncoder,
    +    which keeps all categories.
    +    The output vectors are sparse.
    +
    +    .. seealso::
    +
    +       :py:class:`StringIndexer` for converting categorical values into
    +       category indices
     
         >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
         >>> model = stringIndexer.fit(stringIndDf)
         >>> td = model.transform(stringIndDf)
    -    >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features")
    +    >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features")
         >>> encoder.transform(td).head().features
    -    SparseVector(2, {})
    +    SparseVector(2, {0: 1.0})
         >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs
    -    SparseVector(2, {})
    -    >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"}
    +    SparseVector(2, {0: 1.0})
    +    >>> params = {encoder.dropLast: False, encoder.outputCol: "test"}
         >>> encoder.transform(td, params).head().test
         SparseVector(3, {0: 1.0})
         """
     
    -    _java_class = "org.apache.spark.ml.feature.OneHotEncoder"
    -
         # a placeholder to make it appear in the generated doc
    -    includeFirst = Param(Params._dummy(), "includeFirst", "include first category")
    +    dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category")
     
         @keyword_only
    -    def __init__(self, includeFirst=True, inputCol=None, outputCol=None):
    +    def __init__(self, dropLast=True, inputCol=None, outputCol=None):
             """
-        __init__(self, includeFirst=True, inputCol=None, outputCol=None)
+        __init__(self, dropLast=True, inputCol=None, outputCol=None)
             """
             super(OneHotEncoder, self).__init__()
    -        self.includeFirst = Param(self, "includeFirst", "include first category")
    -        self._setDefault(includeFirst=True)
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
    +        self.dropLast = Param(self, "dropLast", "whether to drop the last category")
    +        self._setDefault(dropLast=True)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
         @keyword_only
    -    def setParams(self, includeFirst=True, inputCol=None, outputCol=None):
    +    def setParams(self, dropLast=True, inputCol=None, outputCol=None):
             """
    -        setParams(self, includeFirst=True, inputCol=None, outputCol=None)
    +        setParams(self, dropLast=True, inputCol=None, outputCol=None)
             Sets params for this OneHotEncoder.
             """
             kwargs = self.setParams._input_kwargs
             return self._set(**kwargs)
     
    -    def setIncludeFirst(self, value):
    +    def setDropLast(self, value):
             """
    -        Sets the value of :py:attr:`includeFirst`.
    +        Sets the value of :py:attr:`dropLast`.
             """
    -        self.paramMap[self.includeFirst] = value
    +        self._paramMap[self.dropLast] = value
             return self
     
    -    def getIncludeFirst(self):
    +    def getDropLast(self):
             """
    -        Gets the value of includeFirst or its default value.
    +        Gets the value of dropLast or its default value.
             """
    -        return self.getOrDefault(self.includeFirst)
    +        return self.getOrDefault(self.dropLast)
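
A sketch of the dropLast behaviour on freshly indexed categories (illustrative data; the StringIndexer step mirrors the doctest above):

```
from pyspark.ml.feature import StringIndexer, OneHotEncoder

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a")], ["id", "category"])
indexed = StringIndexer(inputCol="category", outputCol="categoryIndex") \
    .fit(df).transform(df)

# With 3 categories and dropLast=True (the default) each vector has two slots
# and the last category is encoded as all zeros; dropLast=False keeps one
# slot per category, as scikit-learn does.
encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec")
encoder.transform(indexed).show()
encoder.transform(indexed, {encoder.dropLast: False}).show()
```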
     
     
     @inherit_doc
    @@ -327,8 +480,6 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
         DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
         """
     
    -    _java_class = "org.apache.spark.ml.feature.PolynomialExpansion"
    -
         # a placeholder to make it appear in the generated doc
         degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)")
     
    @@ -338,6 +489,8 @@ def __init__(self, degree=2, inputCol=None, outputCol=None):
             __init__(self, degree=2, inputCol=None, outputCol=None)
             """
             super(PolynomialExpansion, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.feature.PolynomialExpansion", self.uid)
             self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)")
             self._setDefault(degree=2)
             kwargs = self.__init__._input_kwargs
    @@ -356,7 +509,7 @@ def setDegree(self, value):
             """
             Sets the value of :py:attr:`degree`.
             """
    -        self.paramMap[self.degree] = value
    +        self._paramMap[self.degree] = value
             return self
     
         def getDegree(self):
    @@ -370,23 +523,25 @@ def getDegree(self):
     @ignore_unicode_prefix
     class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
    -    A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
    -    or using it to split the text (set matching to false). Optional parameters also allow filtering
    -    tokens using a minimal length.
    +    A regex based tokenizer that extracts tokens either by using the
    +    provided regex pattern (in Java dialect) to split the text
    +    (default) or repeatedly matching the regex (if gaps is true).
    +    Optional parameters also allow filtering tokens using a minimal
    +    length.
         It returns an array of strings that can be empty.
     
    -    >>> df = sqlContext.createDataFrame([("a b c",)], ["text"])
    +    >>> df = sqlContext.createDataFrame([("a b  c",)], ["text"])
         >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words")
         >>> reTokenizer.transform(df).head()
    -    Row(text=u'a b c', words=[u'a', u'b', u'c'])
    +    Row(text=u'a b  c', words=[u'a', u'b', u'c'])
         >>> # Change a parameter.
         >>> reTokenizer.setParams(outputCol="tokens").transform(df).head()
    -    Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
    +    Row(text=u'a b  c', tokens=[u'a', u'b', u'c'])
         >>> # Temporarily modify a parameter.
         >>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head()
    -    Row(text=u'a b c', words=[u'a', u'b', u'c'])
    +    Row(text=u'a b  c', words=[u'a', u'b', u'c'])
         >>> reTokenizer.transform(df).head()
    -    Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
    +    Row(text=u'a b  c', tokens=[u'a', u'b', u'c'])
         >>> # Must use keyword arguments to specify params.
         >>> reTokenizer.setParams("text")
         Traceback (most recent call last):
    @@ -394,33 +549,29 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
         TypeError: Method setParams forces keyword arguments.
         """
     
    -    _java_class = "org.apache.spark.ml.feature.RegexTokenizer"
         # a placeholder to make it appear in the generated doc
         minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)")
    -    gaps = Param(Params._dummy(), "gaps", "Set regex to match gaps or tokens")
    -    pattern = Param(Params._dummy(), "pattern", "regex pattern used for tokenizing")
    +    gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens")
    +    pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing")
     
         @keyword_only
    -    def __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+",
    -                 inputCol=None, outputCol=None):
    +    def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None):
             """
    -        __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+",
    -                 inputCol=None, outputCol=None)
    +        __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None)
             """
             super(RegexTokenizer, self).__init__()
    -        self.minTokenLength = Param(self, "minLength", "minimum token length (>= 0)")
    -        self.gaps = Param(self, "gaps", "Set regex to match gaps or tokens")
    -        self.pattern = Param(self, "pattern", "regex pattern used for tokenizing")
    -        self._setDefault(minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+")
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid)
    +        self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)")
    +        self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens")
    +        self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing")
    +        self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+")
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
         @keyword_only
    -    def setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+",
    -                  inputCol=None, outputCol=None):
    +    def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None):
             """
    -        setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+",
    -                  inputCol="input", outputCol="output")
    +        setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None)
             Sets params for this RegexTokenizer.
             """
             kwargs = self.setParams._input_kwargs
    @@ -430,7 +581,7 @@ def setMinTokenLength(self, value):
             """
             Sets the value of :py:attr:`minTokenLength`.
             """
    -        self.paramMap[self.minTokenLength] = value
    +        self._paramMap[self.minTokenLength] = value
             return self
     
         def getMinTokenLength(self):
    @@ -443,7 +594,7 @@ def setGaps(self, value):
             """
             Sets the value of :py:attr:`gaps`.
             """
    -        self.paramMap[self.gaps] = value
    +        self._paramMap[self.gaps] = value
             return self
     
         def getGaps(self):
    @@ -456,7 +607,7 @@ def setPattern(self, value):
             """
             Sets the value of :py:attr:`pattern`.
             """
    -        self.paramMap[self.pattern] = value
    +        self._paramMap[self.pattern] = value
             return self
     
         def getPattern(self):
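
To make the new defaults concrete: by default the pattern `"\\s+"` is used to split on gaps, while `gaps=False` turns the same pattern into a token matcher (illustrative data, `sqlContext` assumed):

```
from pyspark.ml.feature import RegexTokenizer

df = sqlContext.createDataFrame([("spark, ml-pipeline demo",)], ["text"])

# Default behaviour: split the text on runs of whitespace.
splitter = RegexTokenizer(inputCol="text", outputCol="words")
print(splitter.transform(df).head().words)   # ['spark,', 'ml-pipeline', 'demo']

# gaps=False: the pattern matches the tokens themselves, and minTokenLength
# then drops anything shorter than three characters ("ml" here).
matcher = RegexTokenizer(gaps=False, pattern="[a-z]+", minTokenLength=3,
                         inputCol="text", outputCol="words")
print(matcher.transform(df).head().words)    # ['spark', 'pipeline', 'demo']
```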
    @@ -476,12 +627,14 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
         >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
         >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled")
         >>> model = standardScaler.fit(df)
    +    >>> model.mean
    +    DenseVector([1.0])
    +    >>> model.std
    +    DenseVector([1.4142])
         >>> model.transform(df).collect()[1].scaled
         DenseVector([1.4142])
         """
     
    -    _java_class = "org.apache.spark.ml.feature.StandardScaler"
    -
         # a placeholder to make it appear in the generated doc
         withMean = Param(Params._dummy(), "withMean", "Center data with mean")
         withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation")
    @@ -492,6 +645,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None):
             __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None)
             """
             super(StandardScaler, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid)
             self.withMean = Param(self, "withMean", "Center data with mean")
             self.withStd = Param(self, "withStd", "Scale to unit standard deviation")
             self._setDefault(withMean=False, withStd=True)
    @@ -511,7 +665,7 @@ def setWithMean(self, value):
             """
             Sets the value of :py:attr:`withMean`.
             """
    -        self.paramMap[self.withMean] = value
    +        self._paramMap[self.withMean] = value
             return self
     
         def getWithMean(self):
    @@ -524,7 +678,7 @@ def setWithStd(self, value):
             """
             Sets the value of :py:attr:`withStd`.
             """
    -        self.paramMap[self.withStd] = value
    +        self._paramMap[self.withStd] = value
             return self
     
         def getWithStd(self):
    @@ -533,12 +687,29 @@ def getWithStd(self):
             """
             return self.getOrDefault(self.withStd)
     
    +    def _create_model(self, java_model):
    +        return StandardScalerModel(java_model)
    +
     
     class StandardScalerModel(JavaModel):
         """
         Model fitted by StandardScaler.
         """
     
    +    @property
    +    def std(self):
    +        """
    +        Standard deviation of the StandardScalerModel.
    +        """
    +        return self._call_java("std")
    +
    +    @property
    +    def mean(self):
    +        """
    +        Mean of the StandardScalerModel.
    +        """
    +        return self._call_java("mean")
    +
     
     @inherit_doc
     class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
    @@ -556,14 +727,13 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
         [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
         """
     
    -    _java_class = "org.apache.spark.ml.feature.StringIndexer"
    -
         @keyword_only
         def __init__(self, inputCol=None, outputCol=None):
             """
             __init__(self, inputCol=None, outputCol=None)
             """
             super(StringIndexer, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
    @@ -576,6 +746,9 @@ def setParams(self, inputCol=None, outputCol=None):
             kwargs = self.setParams._input_kwargs
             return self._set(**kwargs)
     
    +    def _create_model(self, java_model):
    +        return StringIndexerModel(java_model)
    +
     
     class StringIndexerModel(JavaModel):
         """
    @@ -609,14 +782,13 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
         TypeError: Method setParams forces keyword arguments.
         """
     
    -    _java_class = "org.apache.spark.ml.feature.Tokenizer"
    -
         @keyword_only
         def __init__(self, inputCol=None, outputCol=None):
             """
             __init__(self, inputCol=None, outputCol=None)
             """
             super(Tokenizer, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
    @@ -646,14 +818,13 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
         DenseVector([0.0, 1.0])
         """
     
    -    _java_class = "org.apache.spark.ml.feature.VectorAssembler"
    -
         @keyword_only
         def __init__(self, inputCols=None, outputCol=None):
             """
             __init__(self, inputCols=None, outputCol=None)
             """
             super(VectorAssembler, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
    @@ -720,7 +891,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
         DenseVector([1.0, 0.0])
         """
     
    -    _java_class = "org.apache.spark.ml.feature.VectorIndexer"
         # a placeholder to make it appear in the generated doc
         maxCategories = Param(Params._dummy(), "maxCategories",
                               "Threshold for the number of values a categorical feature can take " +
    @@ -733,6 +903,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
             __init__(self, maxCategories=20, inputCol=None, outputCol=None)
             """
             super(VectorIndexer, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
             self.maxCategories = Param(self, "maxCategories",
                                        "Threshold for the number of values a categorical feature " +
                                        "can take (>= 2). If a feature is found to have " +
    @@ -754,7 +925,7 @@ def setMaxCategories(self, value):
             """
             Sets the value of :py:attr:`maxCategories`.
             """
    -        self.paramMap[self.maxCategories] = value
    +        self._paramMap[self.maxCategories] = value
             return self
     
         def getMaxCategories(self):
    @@ -763,6 +934,15 @@ def getMaxCategories(self):
             """
             return self.getOrDefault(self.maxCategories)
     
    +    def _create_model(self, java_model):
    +        return VectorIndexerModel(java_model)
    +
    +
    +class VectorIndexerModel(JavaModel):
    +    """
    +    Model fitted by VectorIndexer.
    +    """
    +
     
     @inherit_doc
     @ignore_unicode_prefix
    @@ -778,7 +958,6 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
         DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
         """
     
    -    _java_class = "org.apache.spark.ml.feature.Word2Vec"
         # a placeholder to make it appear in the generated doc
         vectorSize = Param(Params._dummy(), "vectorSize",
                            "the dimension of codes after transforming from words")
    @@ -790,12 +969,13 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     
         @keyword_only
         def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
    -                 seed=42, inputCol=None, outputCol=None):
    +                 seed=None, inputCol=None, outputCol=None):
             """
    -        __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
    -                 seed=42, inputCol=None, outputCol=None)
    +        __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \
    +                 seed=None, inputCol=None, outputCol=None)
             """
             super(Word2Vec, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
             self.vectorSize = Param(self, "vectorSize",
                                     "the dimension of codes after transforming from words")
             self.numPartitions = Param(self, "numPartitions",
    @@ -804,15 +984,15 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025,
                                   "the minimum number of times a token must appear to be included " +
                                   "in the word2vec model's vocabulary")
             self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
    -                         seed=42)
    +                         seed=None)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
         @keyword_only
         def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
    -                  seed=42, inputCol=None, outputCol=None):
    +                  seed=None, inputCol=None, outputCol=None):
             """
    -        setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42,
    +        setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \
                      inputCol=None, outputCol=None)
             Sets params for this Word2Vec.
             """
    @@ -823,7 +1003,7 @@ def setVectorSize(self, value):
             """
             Sets the value of :py:attr:`vectorSize`.
             """
    -        self.paramMap[self.vectorSize] = value
    +        self._paramMap[self.vectorSize] = value
             return self
     
         def getVectorSize(self):
    @@ -836,7 +1016,7 @@ def setNumPartitions(self, value):
             """
             Sets the value of :py:attr:`numPartitions`.
             """
    -        self.paramMap[self.numPartitions] = value
    +        self._paramMap[self.numPartitions] = value
             return self
     
         def getNumPartitions(self):
    @@ -849,7 +1029,7 @@ def setMinCount(self, value):
             """
             Sets the value of :py:attr:`minCount`.
             """
    -        self.paramMap[self.minCount] = value
    +        self._paramMap[self.minCount] = value
             return self
     
         def getMinCount(self):
    @@ -858,6 +1038,9 @@ def getMinCount(self):
             """
             return self.getOrDefault(self.minCount)
     
    +    def _create_model(self, java_model):
    +        return Word2VecModel(java_model)
    +
     
     class Word2VecModel(JavaModel):
         """
    diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
    index 49c20b4cf70cf..7845536161e07 100644
    --- a/python/pyspark/ml/param/__init__.py
    +++ b/python/pyspark/ml/param/__init__.py
    @@ -16,6 +16,7 @@
     #
     
     from abc import ABCMeta
    +import copy
     
     from pyspark.ml.util import Identifiable
     
    @@ -29,9 +30,9 @@ class Param(object):
         """
     
         def __init__(self, parent, name, doc):
    -        if not isinstance(parent, Params):
    -            raise TypeError("Parent must be a Params but got type %s." % type(parent))
    -        self.parent = parent
    +        if not isinstance(parent, Identifiable):
    +            raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
    +        self.parent = parent.uid
             self.name = str(name)
             self.doc = str(doc)
     
    @@ -41,6 +42,15 @@ def __str__(self):
         def __repr__(self):
             return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
     
    +    def __hash__(self):
    +        return hash(str(self))
    +
    +    def __eq__(self, other):
    +        if isinstance(other, Param):
    +            return self.parent == other.parent and self.name == other.name
    +        else:
    +            return False
    +
     
     class Params(Identifiable):
         """
    @@ -51,10 +61,13 @@ class Params(Identifiable):
         __metaclass__ = ABCMeta
     
         #: internal param map for user-supplied values param map
    -    paramMap = {}
    +    _paramMap = {}
     
         #: internal param map for default values
    -    defaultParamMap = {}
    +    _defaultParamMap = {}
    +
    +    #: value returned by :py:func:`params`
    +    _params = None
     
         @property
         def params(self):
    @@ -63,10 +76,12 @@ def params(self):
             uses :py:func:`dir` to get all attributes of type
             :py:class:`Param`.
             """
    -        return list(filter(lambda attr: isinstance(attr, Param),
    -                           [getattr(self, x) for x in dir(self) if x != "params"]))
    +        if self._params is None:
    +            self._params = list(filter(lambda attr: isinstance(attr, Param),
    +                                       [getattr(self, x) for x in dir(self) if x != "params"]))
    +        return self._params
     
    -    def _explain(self, param):
    +    def explainParam(self, param):
             """
             Explains a single param and returns its name, doc, and optional
             default value and user-supplied value in a string.
    @@ -74,10 +89,10 @@ def _explain(self, param):
             param = self._resolveParam(param)
             values = []
             if self.isDefined(param):
    -            if param in self.defaultParamMap:
    -                values.append("default: %s" % self.defaultParamMap[param])
    -            if param in self.paramMap:
    -                values.append("current: %s" % self.paramMap[param])
    +            if param in self._defaultParamMap:
    +                values.append("default: %s" % self._defaultParamMap[param])
    +            if param in self._paramMap:
    +                values.append("current: %s" % self._paramMap[param])
             else:
                 values.append("undefined")
             valueStr = "(" + ", ".join(values) + ")"
    @@ -88,7 +103,7 @@ def explainParams(self):
             Returns the documentation of all params with their optionally
             default values and user-supplied values.
             """
    -        return "\n".join([self._explain(param) for param in self.params])
    +        return "\n".join([self.explainParam(param) for param in self.params])
     
         def getParam(self, paramName):
             """
    @@ -105,56 +120,76 @@ def isSet(self, param):
             Checks whether a param is explicitly set by user.
             """
             param = self._resolveParam(param)
    -        return param in self.paramMap
    +        return param in self._paramMap
     
         def hasDefault(self, param):
             """
             Checks whether a param has a default value.
             """
             param = self._resolveParam(param)
    -        return param in self.defaultParamMap
    +        return param in self._defaultParamMap
     
         def isDefined(self, param):
             """
    -        Checks whether a param is explicitly set by user or has a default value.
    +        Checks whether a param is explicitly set by user or has
    +        a default value.
             """
             return self.isSet(param) or self.hasDefault(param)
     
    +    def hasParam(self, paramName):
    +        """
    +        Tests whether this instance contains a param with a given
    +        (string) name.
    +        """
    +        param = self._resolveParam(paramName)
    +        return param in self.params
    +
         def getOrDefault(self, param):
             """
             Gets the value of a param in the user-supplied param map or its
    -        default value. Raises an error if either is set.
    +        default value. Raises an error if neither is set.
             """
    -        if isinstance(param, Param):
    -            if param in self.paramMap:
    -                return self.paramMap[param]
    -            else:
    -                return self.defaultParamMap[param]
    -        elif isinstance(param, str):
    -            return self.getOrDefault(self.getParam(param))
    +        param = self._resolveParam(param)
    +        if param in self._paramMap:
    +            return self._paramMap[param]
             else:
    -            raise KeyError("Cannot recognize %r as a param." % param)
    +            return self._defaultParamMap[param]
     
    -    def extractParamMap(self, extraParamMap={}):
    +    def extractParamMap(self, extra={}):
             """
             Extracts the embedded default param values and user-supplied
             values, and then merges them with extra values from input into
             a flat param map, where the latter value is used if there exist
             conflicts, i.e., with ordering: default param values <
    -        user-supplied values < extraParamMap.
    -        :param extraParamMap: extra param values
    +        user-supplied values < extra.
    +        :param extra: extra param values
             :return: merged param map
             """
    -        paramMap = self.defaultParamMap.copy()
    -        paramMap.update(self.paramMap)
    -        paramMap.update(extraParamMap)
    +        paramMap = self._defaultParamMap.copy()
    +        paramMap.update(self._paramMap)
    +        paramMap.update(extra)
             return paramMap
     
    +    def copy(self, extra={}):
    +        """
    +        Creates a copy of this instance with the same uid and some
    +        extra params. The default implementation creates a
    +        shallow copy using :py:func:`copy.copy`, and then copies the
    +        embedded and extra parameters over and returns the copy.
    +        Subclasses should override this method if the default approach
    +        is not sufficient.
    +        :param extra: Extra parameters to copy to the new instance
    +        :return: Copy of this instance
    +        """
    +        that = copy.copy(self)
    +        that._paramMap = self.extractParamMap(extra)
    +        return that
    +
         def _shouldOwn(self, param):
             """
             Validates that the input param belongs to this Params instance.
             """
    -        if param.parent is not self:
    +        if not (self.uid == param.parent and self.hasParam(param.name)):
                 raise ValueError("Param %r does not belong to %r." % (param, self))
     
         def _resolveParam(self, param):
    @@ -175,7 +210,8 @@ def _resolveParam(self, param):
         @staticmethod
         def _dummy():
             """
    -        Returns a dummy Params instance used as a placeholder to generate docs.
    +        Returns a dummy Params instance used as a placeholder to
    +        generate docs.
             """
             dummy = Params()
             dummy.uid = "undefined"
    @@ -186,7 +222,7 @@ def _set(self, **kwargs):
             Sets user-supplied params.
             """
             for param, value in kwargs.items():
    -            self.paramMap[getattr(self, param)] = value
    +            self._paramMap[getattr(self, param)] = value
             return self
     
         def _setDefault(self, **kwargs):
    @@ -194,5 +230,19 @@ def _setDefault(self, **kwargs):
             Sets default params.
             """
             for param, value in kwargs.items():
    -            self.defaultParamMap[getattr(self, param)] = value
    +            self._defaultParamMap[getattr(self, param)] = value
             return self
    +
    +    def _copyValues(self, to, extra={}):
    +        """
    +        Copies param values from this instance to another instance for
    +        params shared by them.
    +        :param to: the target instance
    +        :param extra: extra params to be copied
    +        :return: the target instance with param values copied
    +        """
    +        paramMap = self.extractParamMap(extra)
    +        for p in self.params:
    +            if p in paramMap and to.hasParam(p.name):
    +                to._set(**{p.name: paramMap[p]})
    +        return to
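
The merge order implemented by `extractParamMap` above is: default values, then user-supplied values, then the `extra` argument, with later entries winning; `copy(extra)` hands the shallow copy that merged map as its `_paramMap`. A small behavioral sketch; the throwaway `Dummy` subclass of the `HasMaxIter` mixin exists only so that no SparkContext is needed:

```
from pyspark.ml.param.shared import HasMaxIter

class Dummy(HasMaxIter):
    """Toy Params holder used only to exercise the param-map plumbing."""

a = Dummy()
a.setMaxIter(5)                                  # user-supplied value
merged = a.extractParamMap({a.maxIter: 10})      # extra overrides the user-supplied 5
print(merged[a.maxIter])                         # 10

b = a.copy({a.maxIter: 10})                      # shallow copy, same uid, merged _paramMap
print(b.getOrDefault(b.maxIter))                 # 10
print(a.getOrDefault(a.maxIter))                 # 5: the original instance is untouched
```
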
    diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
    index 4a5cc6e64f023..69efc424ec4ef 100644
    --- a/python/pyspark/ml/param/_shared_params_code_gen.py
    +++ b/python/pyspark/ml/param/_shared_params_code_gen.py
    @@ -56,9 +56,10 @@ def _gen_param_header(name, doc, defaultValueStr):
         def __init__(self):
             super(Has$Name, self).__init__()
             #: param for $doc
    -        self.$name = Param(self, "$name", "$doc")
    -        if $defaultValueStr is not None:
    -            self._setDefault($name=$defaultValueStr)'''
    +        self.$name = Param(self, "$name", "$doc")'''
    +    if defaultValueStr is not None:
    +        template += '''
    +        self._setDefault($name=$defaultValueStr)'''
     
         Name = name[0].upper() + name[1:]
         return template \
    @@ -83,7 +84,7 @@ def set$Name(self, value):
             """
             Sets the value of :py:attr:`$name`.
             """
    -        self.paramMap[self.$name] = value
    +        self._paramMap[self.$name] = value
             return self
     
         def get$Name(self):
    @@ -109,13 +110,16 @@ def get$Name(self):
             ("featuresCol", "features column name", "'features'"),
             ("labelCol", "label column name", "'label'"),
             ("predictionCol", "prediction column name", "'prediction'"),
    +        ("probabilityCol", "Column name for predicted class conditional probabilities. " +
    +         "Note: Not all models output well-calibrated probability estimates! These probabilities " +
    +         "should be treated as confidences, not precise probabilities.", "'probability'"),
             ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"),
             ("inputCol", "input column name", None),
             ("inputCols", "input column names", None),
    -        ("outputCol", "output column name", None),
    +        ("outputCol", "output column name", "self.uid + '__output'"),
             ("numFeatures", "number of features", None),
             ("checkpointInterval", "checkpoint interval (>= 1)", None),
    -        ("seed", "random seed", None),
    +        ("seed", "random seed", "hash(type(self).__name__)"),
             ("tol", "the convergence tolerance for iterative algorithms", None),
             ("stepSize", "Step size to be used for each iteration of optimization.", None)]
         code = []
    @@ -156,6 +160,7 @@ def __init__(self):
         for name, doc in decisionTreeParams:
             variable = paramTemplate.replace("$name", name).replace("$doc", doc)
             dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n    "
    +        realParams += "#: param for " + doc + "\n        "
             realParams += "self." + variable.replace("$owner", "self") + "\n        "
             dtParamMethods += _gen_param_code(name, doc, None) + "\n"
         code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders)
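
The generator change above decides at generation time whether to emit a `_setDefault` call, instead of baking a constant-folded `if $defaultValueStr is not None:` branch into every mixin (those dead `if None is not None:` blocks are removed from shared.py in the next file). A rough standalone sketch of that decision using the same `$`-placeholder substitution style; `gen_init` is a hypothetical name, not the generator's actual function:

```
header = '''    def __init__(self):
        super(Has$Name, self).__init__()
        #: param for $doc
        self.$name = Param(self, "$name", "$doc")'''

def gen_init(name, doc, defaultValueStr):
    template = header
    if defaultValueStr is not None:       # decided here, at generation time
        template += "\n        self._setDefault($name=%s)" % defaultValueStr
    Name = name[0].upper() + name[1:]
    return template.replace("$Name", Name).replace("$name", name).replace("$doc", doc)

print(gen_init("maxIter", "max number of iterations (>= 0)", None))
print(gen_init("featuresCol", "features column name", "'features'"))
```
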
    diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
    index 779cabe853f8e..bc088e4c29e26 100644
    --- a/python/pyspark/ml/param/shared.py
    +++ b/python/pyspark/ml/param/shared.py
    @@ -32,14 +32,12 @@ def __init__(self):
             super(HasMaxIter, self).__init__()
             #: param for max number of iterations (>= 0)
             self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)")
    -        if None is not None:
    -            self._setDefault(maxIter=None)
     
         def setMaxIter(self, value):
             """
             Sets the value of :py:attr:`maxIter`.
             """
    -        self.paramMap[self.maxIter] = value
    +        self._paramMap[self.maxIter] = value
             return self
     
         def getMaxIter(self):
    @@ -61,14 +59,12 @@ def __init__(self):
             super(HasRegParam, self).__init__()
             #: param for regularization parameter (>= 0)
             self.regParam = Param(self, "regParam", "regularization parameter (>= 0)")
    -        if None is not None:
    -            self._setDefault(regParam=None)
     
         def setRegParam(self, value):
             """
             Sets the value of :py:attr:`regParam`.
             """
    -        self.paramMap[self.regParam] = value
    +        self._paramMap[self.regParam] = value
             return self
     
         def getRegParam(self):
    @@ -90,14 +86,13 @@ def __init__(self):
             super(HasFeaturesCol, self).__init__()
             #: param for features column name
             self.featuresCol = Param(self, "featuresCol", "features column name")
    -        if 'features' is not None:
    -            self._setDefault(featuresCol='features')
    +        self._setDefault(featuresCol='features')
     
         def setFeaturesCol(self, value):
             """
             Sets the value of :py:attr:`featuresCol`.
             """
    -        self.paramMap[self.featuresCol] = value
    +        self._paramMap[self.featuresCol] = value
             return self
     
         def getFeaturesCol(self):
    @@ -119,14 +114,13 @@ def __init__(self):
             super(HasLabelCol, self).__init__()
             #: param for label column name
             self.labelCol = Param(self, "labelCol", "label column name")
    -        if 'label' is not None:
    -            self._setDefault(labelCol='label')
    +        self._setDefault(labelCol='label')
     
         def setLabelCol(self, value):
             """
             Sets the value of :py:attr:`labelCol`.
             """
    -        self.paramMap[self.labelCol] = value
    +        self._paramMap[self.labelCol] = value
             return self
     
         def getLabelCol(self):
    @@ -148,14 +142,13 @@ def __init__(self):
             super(HasPredictionCol, self).__init__()
             #: param for prediction column name
             self.predictionCol = Param(self, "predictionCol", "prediction column name")
    -        if 'prediction' is not None:
    -            self._setDefault(predictionCol='prediction')
    +        self._setDefault(predictionCol='prediction')
     
         def setPredictionCol(self, value):
             """
             Sets the value of :py:attr:`predictionCol`.
             """
    -        self.paramMap[self.predictionCol] = value
    +        self._paramMap[self.predictionCol] = value
             return self
     
         def getPredictionCol(self):
    @@ -165,6 +158,34 @@ def getPredictionCol(self):
             return self.getOrDefault(self.predictionCol)
     
     
    +class HasProbabilityCol(Params):
    +    """
     +    Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
    +    """
    +
    +    # a placeholder to make it appear in the generated doc
    +    probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
    +
    +    def __init__(self):
    +        super(HasProbabilityCol, self).__init__()
    +        #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
    +        self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
    +        self._setDefault(probabilityCol='probability')
    +
    +    def setProbabilityCol(self, value):
    +        """
    +        Sets the value of :py:attr:`probabilityCol`.
    +        """
    +        self._paramMap[self.probabilityCol] = value
    +        return self
    +
    +    def getProbabilityCol(self):
    +        """
    +        Gets the value of probabilityCol or its default value.
    +        """
    +        return self.getOrDefault(self.probabilityCol)
    +
    +
     class HasRawPredictionCol(Params):
         """
         Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name.
    @@ -177,14 +198,13 @@ def __init__(self):
             super(HasRawPredictionCol, self).__init__()
             #: param for raw prediction (a.k.a. confidence) column name
             self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name")
    -        if 'rawPrediction' is not None:
    -            self._setDefault(rawPredictionCol='rawPrediction')
    +        self._setDefault(rawPredictionCol='rawPrediction')
     
         def setRawPredictionCol(self, value):
             """
             Sets the value of :py:attr:`rawPredictionCol`.
             """
    -        self.paramMap[self.rawPredictionCol] = value
    +        self._paramMap[self.rawPredictionCol] = value
             return self
     
         def getRawPredictionCol(self):
    @@ -206,14 +226,12 @@ def __init__(self):
             super(HasInputCol, self).__init__()
             #: param for input column name
             self.inputCol = Param(self, "inputCol", "input column name")
    -        if None is not None:
    -            self._setDefault(inputCol=None)
     
         def setInputCol(self, value):
             """
             Sets the value of :py:attr:`inputCol`.
             """
    -        self.paramMap[self.inputCol] = value
    +        self._paramMap[self.inputCol] = value
             return self
     
         def getInputCol(self):
    @@ -235,14 +253,12 @@ def __init__(self):
             super(HasInputCols, self).__init__()
             #: param for input column names
             self.inputCols = Param(self, "inputCols", "input column names")
    -        if None is not None:
    -            self._setDefault(inputCols=None)
     
         def setInputCols(self, value):
             """
             Sets the value of :py:attr:`inputCols`.
             """
    -        self.paramMap[self.inputCols] = value
    +        self._paramMap[self.inputCols] = value
             return self
     
         def getInputCols(self):
    @@ -264,14 +280,13 @@ def __init__(self):
             super(HasOutputCol, self).__init__()
             #: param for output column name
             self.outputCol = Param(self, "outputCol", "output column name")
    -        if None is not None:
    -            self._setDefault(outputCol=None)
    +        self._setDefault(outputCol=self.uid + '__output')
     
         def setOutputCol(self, value):
             """
             Sets the value of :py:attr:`outputCol`.
             """
    -        self.paramMap[self.outputCol] = value
    +        self._paramMap[self.outputCol] = value
             return self
     
         def getOutputCol(self):
    @@ -293,14 +308,12 @@ def __init__(self):
             super(HasNumFeatures, self).__init__()
             #: param for number of features
             self.numFeatures = Param(self, "numFeatures", "number of features")
    -        if None is not None:
    -            self._setDefault(numFeatures=None)
     
         def setNumFeatures(self, value):
             """
             Sets the value of :py:attr:`numFeatures`.
             """
    -        self.paramMap[self.numFeatures] = value
    +        self._paramMap[self.numFeatures] = value
             return self
     
         def getNumFeatures(self):
    @@ -322,14 +335,12 @@ def __init__(self):
             super(HasCheckpointInterval, self).__init__()
             #: param for checkpoint interval (>= 1)
             self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)")
    -        if None is not None:
    -            self._setDefault(checkpointInterval=None)
     
         def setCheckpointInterval(self, value):
             """
             Sets the value of :py:attr:`checkpointInterval`.
             """
    -        self.paramMap[self.checkpointInterval] = value
    +        self._paramMap[self.checkpointInterval] = value
             return self
     
         def getCheckpointInterval(self):
    @@ -351,14 +362,13 @@ def __init__(self):
             super(HasSeed, self).__init__()
             #: param for random seed
             self.seed = Param(self, "seed", "random seed")
    -        if None is not None:
    -            self._setDefault(seed=None)
    +        self._setDefault(seed=hash(type(self).__name__))
     
         def setSeed(self, value):
             """
             Sets the value of :py:attr:`seed`.
             """
    -        self.paramMap[self.seed] = value
    +        self._paramMap[self.seed] = value
             return self
     
         def getSeed(self):
    @@ -380,14 +390,12 @@ def __init__(self):
             super(HasTol, self).__init__()
             #: param for the convergence tolerance for iterative algorithms
             self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms")
    -        if None is not None:
    -            self._setDefault(tol=None)
     
         def setTol(self, value):
             """
             Sets the value of :py:attr:`tol`.
             """
    -        self.paramMap[self.tol] = value
    +        self._paramMap[self.tol] = value
             return self
     
         def getTol(self):
    @@ -409,14 +417,12 @@ def __init__(self):
             super(HasStepSize, self).__init__()
             #: param for Step size to be used for each iteration of optimization.
             self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.")
    -        if None is not None:
    -            self._setDefault(stepSize=None)
     
         def setStepSize(self, value):
             """
             Sets the value of :py:attr:`stepSize`.
             """
    -        self.paramMap[self.stepSize] = value
    +        self._paramMap[self.stepSize] = value
             return self
     
         def getStepSize(self):
    @@ -438,6 +444,7 @@ class DecisionTreeParams(Params):
         minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
         maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
         cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
    +    
     
         def __init__(self):
             super(DecisionTreeParams, self).__init__()
    @@ -453,12 +460,12 @@ def __init__(self):
             self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
             #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
             self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
    -
    +        
         def setMaxDepth(self, value):
             """
             Sets the value of :py:attr:`maxDepth`.
             """
    -        self.paramMap[self.maxDepth] = value
    +        self._paramMap[self.maxDepth] = value
             return self
     
         def getMaxDepth(self):
    @@ -471,7 +478,7 @@ def setMaxBins(self, value):
             """
             Sets the value of :py:attr:`maxBins`.
             """
    -        self.paramMap[self.maxBins] = value
    +        self._paramMap[self.maxBins] = value
             return self
     
         def getMaxBins(self):
    @@ -484,7 +491,7 @@ def setMinInstancesPerNode(self, value):
             """
             Sets the value of :py:attr:`minInstancesPerNode`.
             """
    -        self.paramMap[self.minInstancesPerNode] = value
    +        self._paramMap[self.minInstancesPerNode] = value
             return self
     
         def getMinInstancesPerNode(self):
    @@ -497,7 +504,7 @@ def setMinInfoGain(self, value):
             """
             Sets the value of :py:attr:`minInfoGain`.
             """
    -        self.paramMap[self.minInfoGain] = value
    +        self._paramMap[self.minInfoGain] = value
             return self
     
         def getMinInfoGain(self):
    @@ -510,7 +517,7 @@ def setMaxMemoryInMB(self, value):
             """
             Sets the value of :py:attr:`maxMemoryInMB`.
             """
    -        self.paramMap[self.maxMemoryInMB] = value
    +        self._paramMap[self.maxMemoryInMB] = value
             return self
     
         def getMaxMemoryInMB(self):
    @@ -523,7 +530,7 @@ def setCacheNodeIds(self, value):
             """
             Sets the value of :py:attr:`cacheNodeIds`.
             """
    -        self.paramMap[self.cacheNodeIds] = value
    +        self._paramMap[self.cacheNodeIds] = value
             return self
     
         def getCacheNodeIds(self):
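
Two user-visible effects of the regenerated mixins above: `HasSeed` now carries a deterministic default of `hash(type(self).__name__)` rather than no default, and `HasOutputCol` defaults to `self.uid + '__output'`. A quick check with throwaway subclasses (again only to avoid needing a SparkContext); the uid suffix is random, so the second output is indicative only:

```
from pyspark.ml.param.shared import HasOutputCol, HasSeed

class WithSeed(HasSeed):
    """Toy holder to observe the new default seed."""

class WithOutput(HasOutputCol):
    """Toy holder to observe the uid-derived output column."""

s = WithSeed()
print(s.getSeed() == hash("WithSeed"))   # True: the default seed is the hash of the class name

o = WithOutput()
print(o.getOutputCol())                  # e.g. 'WithOutput_<random hex>__output'
```
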
    diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
    index a328bcf84a2e7..9889f56cac9e4 100644
    --- a/python/pyspark/ml/pipeline.py
    +++ b/python/pyspark/ml/pipeline.py
    @@ -31,18 +31,42 @@ class Estimator(Params):
         __metaclass__ = ABCMeta
     
         @abstractmethod
    -    def fit(self, dataset, params={}):
    +    def _fit(self, dataset):
             """
    -        Fits a model to the input dataset with optional parameters.
    +        Fits a model to the input dataset. This is called by the
    +        default implementation of fit.
     
             :param dataset: input dataset, which is an instance of
                             :py:class:`pyspark.sql.DataFrame`
    -        :param params: an optional param map that overwrites embedded
    -                       params
             :returns: fitted model
             """
             raise NotImplementedError()
     
    +    def fit(self, dataset, params=None):
    +        """
    +        Fits a model to the input dataset with optional parameters.
    +
    +        :param dataset: input dataset, which is an instance of
    +                        :py:class:`pyspark.sql.DataFrame`
    +        :param params: an optional param map that overrides embedded
    +                       params. If a list/tuple of param maps is given,
    +                       this calls fit on each param map and returns a
    +                       list of models.
    +        :returns: fitted model(s)
    +        """
    +        if params is None:
    +            params = dict()
    +        if isinstance(params, (list, tuple)):
    +            return [self.fit(dataset, paramMap) for paramMap in params]
    +        elif isinstance(params, dict):
    +            if params:
    +                return self.copy(params)._fit(dataset)
    +            else:
    +                return self._fit(dataset)
    +        else:
    +            raise ValueError("Params must be either a param map or a list/tuple of param maps, "
    +                             "but got %s." % type(params))
    +
     
     @inherit_doc
     class Transformer(Params):
    @@ -54,18 +78,36 @@ class Transformer(Params):
         __metaclass__ = ABCMeta
     
         @abstractmethod
    -    def transform(self, dataset, params={}):
    +    def _transform(self, dataset):
             """
              Transforms the input dataset. Called by the default implementation of transform.
     
             :param dataset: input dataset, which is an instance of
                             :py:class:`pyspark.sql.DataFrame`
    -        :param params: an optional param map that overwrites embedded
    -                       params
             :returns: transformed dataset
             """
             raise NotImplementedError()
     
    +    def transform(self, dataset, params=None):
    +        """
    +        Transforms the input dataset with optional parameters.
    +
    +        :param dataset: input dataset, which is an instance of
    +                        :py:class:`pyspark.sql.DataFrame`
    +        :param params: an optional param map that overrides embedded
    +                       params.
    +        :returns: transformed dataset
    +        """
    +        if params is None:
    +            params = dict()
    +        if isinstance(params, dict):
    +            if params:
     +                return self.copy(params)._transform(dataset)
    +            else:
    +                return self._transform(dataset)
    +        else:
    +            raise ValueError("Params must be either a param map but got %s." % type(params))
    +
     
     @inherit_doc
     class Model(Transformer):
    @@ -97,10 +139,12 @@ class Pipeline(Estimator):
         """
     
         @keyword_only
    -    def __init__(self, stages=[]):
    +    def __init__(self, stages=None):
             """
             __init__(self, stages=[])
             """
    +        if stages is None:
    +            stages = []
             super(Pipeline, self).__init__()
             #: Param for pipeline stages.
             self.stages = Param(self, "stages", "pipeline stages")
    @@ -113,28 +157,29 @@ def setStages(self, value):
             :param value: a list of transformers or estimators
             :return: the pipeline instance
             """
    -        self.paramMap[self.stages] = value
    +        self._paramMap[self.stages] = value
             return self
     
         def getStages(self):
             """
             Get pipeline stages.
             """
    -        if self.stages in self.paramMap:
    -            return self.paramMap[self.stages]
    +        if self.stages in self._paramMap:
    +            return self._paramMap[self.stages]
     
         @keyword_only
    -    def setParams(self, stages=[]):
    +    def setParams(self, stages=None):
             """
             setParams(self, stages=[])
             Sets params for Pipeline.
             """
    +        if stages is None:
    +            stages = []
             kwargs = self.setParams._input_kwargs
             return self._set(**kwargs)
     
    -    def fit(self, dataset, params={}):
    -        paramMap = self.extractParamMap(params)
    -        stages = paramMap[self.stages]
    +    def _fit(self, dataset):
    +        stages = self.getStages()
             for stage in stages:
                 if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
                     raise TypeError(
    @@ -148,16 +193,23 @@ def fit(self, dataset, params={}):
                 if i <= indexOfLastEstimator:
                     if isinstance(stage, Transformer):
                         transformers.append(stage)
    -                    dataset = stage.transform(dataset, paramMap)
    +                    dataset = stage.transform(dataset)
                     else:  # must be an Estimator
    -                    model = stage.fit(dataset, paramMap)
    +                    model = stage.fit(dataset)
                         transformers.append(model)
                         if i < indexOfLastEstimator:
    -                        dataset = model.transform(dataset, paramMap)
    +                        dataset = model.transform(dataset)
                 else:
                     transformers.append(stage)
             return PipelineModel(transformers)
     
    +    def copy(self, extra=None):
    +        if extra is None:
    +            extra = dict()
    +        that = Params.copy(self, extra)
    +        stages = [stage.copy(extra) for stage in that.getStages()]
    +        return that.setStages(stages)
    +
     
     @inherit_doc
     class PipelineModel(Model):
    @@ -165,33 +217,17 @@ class PipelineModel(Model):
         Represents a compiled pipeline with transformers and fitted models.
         """
     
    -    def __init__(self, transformers):
    +    def __init__(self, stages):
             super(PipelineModel, self).__init__()
    -        self.transformers = transformers
    +        self.stages = stages
     
    -    def transform(self, dataset, params={}):
    -        paramMap = self.extractParamMap(params)
    -        for t in self.transformers:
    -            dataset = t.transform(dataset, paramMap)
    +    def _transform(self, dataset):
    +        for t in self.stages:
    +            dataset = t.transform(dataset)
             return dataset
     
    -
    -class Evaluator(Params):
    -    """
    -    Base class for evaluators that compute metrics from predictions.
    -    """
    -
    -    __metaclass__ = ABCMeta
    -
    -    @abstractmethod
    -    def evaluate(self, dataset, params={}):
    -        """
    -        Evaluates the output.
    -
    -        :param dataset: a dataset that contains labels/observations and
    -                        predictions
    -        :param params: an optional param map that overrides embedded
    -                       params
    -        :return: metric
    -        """
    -        raise NotImplementedError()
    +    def copy(self, extra=None):
    +        if extra is None:
    +            extra = dict()
    +        stages = [stage.copy(extra) for stage in self.stages]
    +        return PipelineModel(stages)
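
With the template-method split above, `Estimator.fit` and `Transformer.transform` own the param handling: an empty dict falls through to `_fit`/`_transform`, a non-empty dict is applied via `self.copy(params)` first, and, for `fit`, a list or tuple of param maps trains one model per map. A hedged usage sketch; `df` stands for any DataFrame with `features`/`label` columns and is not part of the patch:

```
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(maxIter=5, regParam=0.0)
model = lr.fit(df)                               # dict-free path: fit -> _fit

param_maps = [{lr.regParam: 0.1}, {lr.regParam: 0.01}]
models = lr.fit(df, param_maps)                  # list of param maps -> list of fitted models
print([m.intercept for m in models])             # intercept/weights are the new model properties
```
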
    diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
    index 4846b907e85ec..b06099ac0aee6 100644
    --- a/python/pyspark/ml/recommendation.py
    +++ b/python/pyspark/ml/recommendation.py
    @@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
         indicated user preferences rather than explicit ratings given to
         items.
     
    +    >>> df = sqlContext.createDataFrame(
    +    ...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
    +    ...     ["user", "item", "rating"])
         >>> als = ALS(rank=10, maxIter=5)
         >>> model = als.fit(df)
    +    >>> model.rank
    +    10
    +    >>> model.userFactors.orderBy("id").collect()
    +    [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
         >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
         >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
         >>> predictions[0]
    @@ -74,7 +81,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
         >>> predictions[2]
         Row(user=2, item=0, prediction=-1.15...)
         """
    -    _java_class = "org.apache.spark.ml.recommendation.ALS"
    +
         # a placeholder to make it appear in the generated doc
         rank = Param(Params._dummy(), "rank", "rank of the factorization")
         numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks")
    @@ -89,14 +96,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     
         @keyword_only
         def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
    -                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
    +                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                      ratingCol="rating", nonnegative=False, checkpointInterval=10):
             """
    -        __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
    -                 implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=0,
    +        __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
    +                 implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                      ratingCol="rating", nonnegative=false, checkpointInterval=10)
             """
             super(ALS, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
             self.rank = Param(self, "rank", "rank of the factorization")
             self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks")
             self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks")
    @@ -108,18 +116,18 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB
             self.nonnegative = Param(self, "nonnegative",
                                      "whether to use nonnegative constraint for least squares")
             self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
    -                         implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
    +                         implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                              ratingCol="rating", nonnegative=False, checkpointInterval=10)
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
     
         @keyword_only
         def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
    -                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
    +                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                       ratingCol="rating", nonnegative=False, checkpointInterval=10):
             """
    -        setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
    -                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
    +        setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
    +                 implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                      ratingCol="rating", nonnegative=False, checkpointInterval=10)
             Sets params for ALS.
             """
    @@ -133,7 +141,7 @@ def setRank(self, value):
             """
             Sets the value of :py:attr:`rank`.
             """
    -        self.paramMap[self.rank] = value
    +        self._paramMap[self.rank] = value
             return self
     
         def getRank(self):
    @@ -146,7 +154,7 @@ def setNumUserBlocks(self, value):
             """
             Sets the value of :py:attr:`numUserBlocks`.
             """
    -        self.paramMap[self.numUserBlocks] = value
    +        self._paramMap[self.numUserBlocks] = value
             return self
     
         def getNumUserBlocks(self):
    @@ -159,7 +167,7 @@ def setNumItemBlocks(self, value):
             """
             Sets the value of :py:attr:`numItemBlocks`.
             """
    -        self.paramMap[self.numItemBlocks] = value
    +        self._paramMap[self.numItemBlocks] = value
             return self
     
         def getNumItemBlocks(self):
    @@ -172,14 +180,14 @@ def setNumBlocks(self, value):
             """
             Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value.
             """
    -        self.paramMap[self.numUserBlocks] = value
    -        self.paramMap[self.numItemBlocks] = value
    +        self._paramMap[self.numUserBlocks] = value
    +        self._paramMap[self.numItemBlocks] = value
     
         def setImplicitPrefs(self, value):
             """
             Sets the value of :py:attr:`implicitPrefs`.
             """
    -        self.paramMap[self.implicitPrefs] = value
    +        self._paramMap[self.implicitPrefs] = value
             return self
     
         def getImplicitPrefs(self):
    @@ -192,7 +200,7 @@ def setAlpha(self, value):
             """
             Sets the value of :py:attr:`alpha`.
             """
    -        self.paramMap[self.alpha] = value
    +        self._paramMap[self.alpha] = value
             return self
     
         def getAlpha(self):
    @@ -205,7 +213,7 @@ def setUserCol(self, value):
             """
             Sets the value of :py:attr:`userCol`.
             """
    -        self.paramMap[self.userCol] = value
    +        self._paramMap[self.userCol] = value
             return self
     
         def getUserCol(self):
    @@ -218,7 +226,7 @@ def setItemCol(self, value):
             """
             Sets the value of :py:attr:`itemCol`.
             """
    -        self.paramMap[self.itemCol] = value
    +        self._paramMap[self.itemCol] = value
             return self
     
         def getItemCol(self):
    @@ -231,7 +239,7 @@ def setRatingCol(self, value):
             """
             Sets the value of :py:attr:`ratingCol`.
             """
    -        self.paramMap[self.ratingCol] = value
    +        self._paramMap[self.ratingCol] = value
             return self
     
         def getRatingCol(self):
    @@ -244,7 +252,7 @@ def setNonnegative(self, value):
             """
             Sets the value of :py:attr:`nonnegative`.
             """
    -        self.paramMap[self.nonnegative] = value
    +        self._paramMap[self.nonnegative] = value
             return self
     
         def getNonnegative(self):
    @@ -259,6 +267,27 @@ class ALSModel(JavaModel):
         Model fitted by ALS.
         """
     
    +    @property
    +    def rank(self):
    +        """rank of the matrix factorization model"""
    +        return self._call_java("rank")
    +
    +    @property
    +    def userFactors(self):
    +        """
    +        a DataFrame that stores user factors in two columns: `id` and
    +        `features`
    +        """
    +        return self._call_java("userFactors")
    +
    +    @property
    +    def itemFactors(self):
    +        """
    +        a DataFrame that stores item factors in two columns: `id` and
    +        `features`
    +        """
    +        return self._call_java("itemFactors")
    +
     
     if __name__ == "__main__":
         import doctest
    @@ -271,8 +300,6 @@ class ALSModel(JavaModel):
         sqlContext = SQLContext(sc)
         globs['sc'] = sc
         globs['sqlContext'] = sqlContext
    -    globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
    -                                              (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         sc.stop()
         if failure_count:
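
Because the module-level `globs['df']` was removed from the doctest setup above, the ALS doctest now builds its own ratings DataFrame and exercises the new `rank`/`userFactors`/`itemFactors` properties. Roughly the same flow outside a doctest, assuming the usual `sqlContext`:

```
from pyspark.ml.recommendation import ALS

ratings = sqlContext.createDataFrame(
    [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
    ["user", "item", "rating"])
als = ALS(rank=10, maxIter=5, seed=0)    # seed pinned explicitly for reproducible factors
model = als.fit(ratings)
print(model.rank)                        # 10
model.userFactors.orderBy("id").show()   # DataFrame with columns: id, features
model.itemFactors.orderBy("id").show()
```
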
    diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
    index 0ab5c6c3d20c3..44f60a769566d 100644
    --- a/python/pyspark/ml/regression.py
    +++ b/python/pyspark/ml/regression.py
    @@ -33,8 +33,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
         Linear regression.
     
         The learning objective is to minimize the squared error, with regularization.
    -    The specific squared error loss function used is:
    -      L = 1/2n ||A weights - y||^2^
    +    The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^
     
          This supports multiple types of regularization:
          - none (a.k.a. ordinary least squares)
    @@ -51,6 +50,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         -1.0
    +    >>> model.weights
    +    DenseVector([1.0])
    +    >>> model.intercept
    +    0.0
         >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
         >>> model.transform(test1).head().prediction
         1.0
    @@ -59,7 +62,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
             ...
         TypeError: Method setParams forces keyword arguments.
         """
    -    _java_class = "org.apache.spark.ml.regression.LinearRegression"
    +
         # a placeholder to make it appear in the generated doc
         elasticNetParam = \
             Param(Params._dummy(), "elasticNetParam",
    @@ -74,6 +77,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
                      maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6)
             """
             super(LinearRegression, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.regression.LinearRegression", self.uid)
             #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty
             #  is an L2 penalty. For alpha = 1, it is an L1 penalty.
             self.elasticNetParam = \
    @@ -102,7 +107,7 @@ def setElasticNetParam(self, value):
             """
             Sets the value of :py:attr:`elasticNetParam`.
             """
    -        self.paramMap[self.elasticNetParam] = value
    +        self._paramMap[self.elasticNetParam] = value
             return self
     
         def getElasticNetParam(self):
    @@ -117,6 +122,20 @@ class LinearRegressionModel(JavaModel):
         Model fitted by LinearRegression.
         """
     
    +    @property
    +    def weights(self):
    +        """
    +        Model weights.
    +        """
    +        return self._call_java("weights")
    +
    +    @property
    +    def intercept(self):
    +        """
    +        Model intercept.
    +        """
    +        return self._call_java("intercept")
    +
     
     class TreeRegressorParams(object):
         """
    @@ -153,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
         >>> dt = DecisionTreeRegressor(maxDepth=2)
         >>> model = dt.fit(df)
    +    >>> model.depth
    +    1
    +    >>> model.numNodes
    +    3
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -161,7 +184,6 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         1.0
         """
     
    -    _java_class = "org.apache.spark.ml.regression.DecisionTreeRegressor"
         # a placeholder to make it appear in the generated doc
         impurity = Param(Params._dummy(), "impurity",
                          "Criterion used for information gain calculation (case-insensitive). " +
    @@ -173,10 +195,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
                      maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"):
             """
             __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    -                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                      maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
             """
             super(DecisionTreeRegressor, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
             #: param for Criterion used for information gain calculation (case-insensitive).
             self.impurity = \
                 Param(self, "impurity",
    @@ -195,9 +219,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
                       impurity="variance"):
             """
             setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    -                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    -                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
    -                  impurity="variance")
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
             Sets params for the DecisionTreeRegressor.
             """
             kwargs = self.setParams._input_kwargs
    @@ -210,7 +233,7 @@ def setImpurity(self, value):
             """
             Sets the value of :py:attr:`impurity`.
             """
    -        self.paramMap[self.impurity] = value
    +        self._paramMap[self.impurity] = value
             return self
     
         def getImpurity(self):
    @@ -220,7 +243,37 @@ def getImpurity(self):
             return self.getOrDefault(self.impurity)
     
     
    -class DecisionTreeRegressionModel(JavaModel):
    +@inherit_doc
    +class DecisionTreeModel(JavaModel):
    +
    +    @property
    +    def numNodes(self):
    +        """Return number of nodes of the decision tree."""
    +        return self._call_java("numNodes")
    +
    +    @property
    +    def depth(self):
    +        """Return depth of the decision tree."""
    +        return self._call_java("depth")
    +
    +    def __repr__(self):
    +        return self._call_java("toString")
    +
    +
    +@inherit_doc
    +class TreeEnsembleModels(JavaModel):
    +
    +    @property
    +    def treeWeights(self):
    +        """Return the weights for each tree"""
    +        return list(self._call_java("javaTreeWeights"))
    +
    +    def __repr__(self):
    +        return self._call_java("toString")
    +
    +
    +@inherit_doc
    +class DecisionTreeRegressionModel(DecisionTreeModel):
         """
         Model fitted by DecisionTreeRegressor.
         """
    @@ -234,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         learning algorithm for regression.
         It supports both continuous and categorical features.
     
    +    >>> from numpy import allclose
         >>> from pyspark.mllib.linalg import Vectors
         >>> df = sqlContext.createDataFrame([
         ...     (1.0, Vectors.dense(1.0)),
         ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    -    >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
    +    >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
         >>> model = rf.fit(df)
    +    >>> allclose(model.treeWeights, [1.0, 1.0])
    +    True
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -248,7 +304,6 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         0.5
         """
     
    -    _java_class = "org.apache.spark.ml.regression.RandomForestRegressor"
         # a placeholder to make it appear in the generated doc
         impurity = Param(Params._dummy(), "impurity",
                          "Criterion used for information gain calculation (case-insensitive). " +
    @@ -266,14 +321,17 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                      maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                      maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
    -                 numTrees=20, featureSubsetStrategy="auto", seed=42):
    +                 numTrees=20, featureSubsetStrategy="auto", seed=None):
             """
    -        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    -                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    -                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
    -                 numTrees=20, featureSubsetStrategy="auto", seed=42)
    +        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
    +                 impurity="variance", numTrees=20, \
    +                 featureSubsetStrategy="auto", seed=None)
             """
             super(RandomForestRegressor, self).__init__()
    +        self._java_obj = self._new_java_obj(
    +            "org.apache.spark.ml.regression.RandomForestRegressor", self.uid)
             #: param for Criterion used for information gain calculation (case-insensitive).
             self.impurity = \
                 Param(self, "impurity",
    @@ -292,7 +350,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
                       "The number of features to consider for splits at each tree node. Supported " +
                       "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
             self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    -                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
    +                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
                              impurity="variance", numTrees=20, featureSubsetStrategy="auto")
             kwargs = self.__init__._input_kwargs
             self.setParams(**kwargs)
    @@ -300,12 +358,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
         @keyword_only
         def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                       maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    -                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
                       impurity="variance", numTrees=20, featureSubsetStrategy="auto"):
             """
    -        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    -                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    -                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
    +        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
                       impurity="variance", numTrees=20, featureSubsetStrategy="auto")
             Sets params for the RandomForestRegressor.
             """
    @@ -319,7 +377,7 @@ def setImpurity(self, value):
             """
             Sets the value of :py:attr:`impurity`.
             """
    -        self.paramMap[self.impurity] = value
    +        self._paramMap[self.impurity] = value
             return self
     
         def getImpurity(self):
    @@ -332,7 +390,7 @@ def setSubsamplingRate(self, value):
             """
             Sets the value of :py:attr:`subsamplingRate`.
             """
    -        self.paramMap[self.subsamplingRate] = value
    +        self._paramMap[self.subsamplingRate] = value
             return self
     
         def getSubsamplingRate(self):
    @@ -345,7 +403,7 @@ def setNumTrees(self, value):
             """
             Sets the value of :py:attr:`numTrees`.
             """
    -        self.paramMap[self.numTrees] = value
    +        self._paramMap[self.numTrees] = value
             return self
     
         def getNumTrees(self):
    @@ -358,7 +416,7 @@ def setFeatureSubsetStrategy(self, value):
             """
             Sets the value of :py:attr:`featureSubsetStrategy`.
             """
    -        self.paramMap[self.featureSubsetStrategy] = value
    +        self._paramMap[self.featureSubsetStrategy] = value
             return self
     
         def getFeatureSubsetStrategy(self):
    @@ -368,7 +426,7 @@ def getFeatureSubsetStrategy(self):
             return self.getOrDefault(self.featureSubsetStrategy)
     
     
    -class RandomForestRegressionModel(JavaModel):
    +class RandomForestRegressionModel(TreeEnsembleModels):
         """
         Model fitted by RandomForestRegressor.
         """
    @@ -382,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
         learning algorithm for regression.
         It supports both continuous and categorical features.
     
    +    >>> from numpy import allclose
         >>> from pyspark.mllib.linalg import Vectors
         >>> df = sqlContext.createDataFrame([
         ...     (1.0, Vectors.dense(1.0)),
         ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
         >>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
         >>> model = gbt.fit(df)
    +    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    +    True
         >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
         >>> model.transform(test0).head().prediction
         0.0
    @@ -396,7 +457,6 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
         1.0
         """
     
    -    _java_class = "org.apache.spark.ml.regression.GBTRegressor"
         # a placeholder to make it appear in the generated doc
         lossType = Param(Params._dummy(), "lossType",
                          "Loss function which GBT tries to minimize (case-insensitive). " +
    @@ -414,12 +474,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
                      maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared",
                      maxIter=20, stepSize=0.1):
             """
    -        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    -                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    -                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared",
    -                 maxIter=20, stepSize=0.1)
    +        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
    +                 lossType="squared", maxIter=20, stepSize=0.1)
             """
             super(GBTRegressor, self).__init__()
    +        self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
             #: param for Loss function which GBT tries to minimize (case-insensitive).
             self.lossType = Param(self, "lossType",
                                   "Loss function which GBT tries to minimize (case-insensitive). " +
    @@ -445,9 +506,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
                       maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                       lossType="squared", maxIter=20, stepSize=0.1):
             """
    -        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
    -                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
    -                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
    +        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
    +                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
    +                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                       lossType="squared", maxIter=20, stepSize=0.1)
             Sets params for Gradient Boosted Tree Regression.
             """
    @@ -461,7 +522,7 @@ def setLossType(self, value):
             """
             Sets the value of :py:attr:`lossType`.
             """
    -        self.paramMap[self.lossType] = value
    +        self._paramMap[self.lossType] = value
             return self
     
         def getLossType(self):
    @@ -474,7 +535,7 @@ def setSubsamplingRate(self, value):
             """
             Sets the value of :py:attr:`subsamplingRate`.
             """
    -        self.paramMap[self.subsamplingRate] = value
    +        self._paramMap[self.subsamplingRate] = value
             return self
     
         def getSubsamplingRate(self):
    @@ -487,7 +548,7 @@ def setStepSize(self, value):
             """
             Sets the value of :py:attr:`stepSize`.
             """
    -        self.paramMap[self.stepSize] = value
    +        self._paramMap[self.stepSize] = value
             return self
     
         def getStepSize(self):
    @@ -497,7 +558,7 @@ def getStepSize(self):
             return self.getOrDefault(self.stepSize)
     
     
    -class GBTRegressionModel(JavaModel):
    +class GBTRegressionModel(TreeEnsembleModels):
         """
         Model fitted by GBTRegressor.
         """
    diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
    index ba6478dcd58a9..c151d21fd661a 100644
    --- a/python/pyspark/ml/tests.py
    +++ b/python/pyspark/ml/tests.py
    @@ -31,10 +31,13 @@
         import unittest
     
     from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
    -from pyspark.sql import DataFrame
    -from pyspark.ml.param import Param
    -from pyspark.ml.param.shared import HasMaxIter, HasInputCol
    -from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer
    +from pyspark.sql import DataFrame, SQLContext
    +from pyspark.ml.param import Param, Params
    +from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
    +from pyspark.ml.util import keyword_only
    +from pyspark.ml import Estimator, Model, Pipeline, Transformer
    +from pyspark.ml.feature import *
    +from pyspark.mllib.linalg import DenseVector
     
     
     class MockDataset(DataFrame):
    @@ -43,44 +46,43 @@ def __init__(self):
             self.index = 0
     
     
    -class MockTransformer(Transformer):
    +class HasFake(Params):
    +
    +    def __init__(self):
    +        super(HasFake, self).__init__()
    +        self.fake = Param(self, "fake", "fake param")
    +
    +    def getFake(self):
    +        return self.getOrDefault(self.fake)
    +
    +
    +class MockTransformer(Transformer, HasFake):
     
         def __init__(self):
             super(MockTransformer, self).__init__()
    -        self.fake = Param(self, "fake", "fake")
             self.dataset_index = None
    -        self.fake_param_value = None
     
    -    def transform(self, dataset, params={}):
    +    def _transform(self, dataset):
             self.dataset_index = dataset.index
    -        if self.fake in params:
    -            self.fake_param_value = params[self.fake]
             dataset.index += 1
             return dataset
     
     
    -class MockEstimator(Estimator):
    +class MockEstimator(Estimator, HasFake):
     
         def __init__(self):
             super(MockEstimator, self).__init__()
    -        self.fake = Param(self, "fake", "fake")
             self.dataset_index = None
    -        self.fake_param_value = None
    -        self.model = None
     
    -    def fit(self, dataset, params={}):
    +    def _fit(self, dataset):
             self.dataset_index = dataset.index
    -        if self.fake in params:
    -            self.fake_param_value = params[self.fake]
             model = MockModel()
    -        self.model = model
    +        self._copyValues(model)
             return model
     
     
    -class MockModel(MockTransformer, Model):
    -
    -    def __init__(self):
    -        super(MockModel, self).__init__()
    +class MockModel(MockTransformer, Model, HasFake):
    +    pass
     
     
     class PipelineTests(PySparkTestCase):
    @@ -91,19 +93,17 @@ def test_pipeline(self):
             transformer1 = MockTransformer()
             estimator2 = MockEstimator()
             transformer3 = MockTransformer()
    -        pipeline = Pipeline() \
    -            .setStages([estimator0, transformer1, estimator2, transformer3])
    +        pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3])
             pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
    -        self.assertEqual(0, estimator0.dataset_index)
    -        self.assertEqual(0, estimator0.fake_param_value)
    -        model0 = estimator0.model
    +        model0, transformer1, model2, transformer3 = pipeline_model.stages
             self.assertEqual(0, model0.dataset_index)
    +        self.assertEqual(0, model0.getFake())
             self.assertEqual(1, transformer1.dataset_index)
    -        self.assertEqual(1, transformer1.fake_param_value)
    -        self.assertEqual(2, estimator2.dataset_index)
    -        model2 = estimator2.model
    -        self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
    -                                                "not be called during fit.")
    +        self.assertEqual(1, transformer1.getFake())
    +        self.assertEqual(2, dataset.index)
    +        self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.")
    +        self.assertIsNone(transformer3.dataset_index,
    +                          "The last transformer shouldn't be called in fit.")
             dataset = pipeline_model.transform(dataset)
             self.assertEqual(2, model0.dataset_index)
             self.assertEqual(3, transformer1.dataset_index)
    @@ -112,14 +112,46 @@ def test_pipeline(self):
             self.assertEqual(6, dataset.index)
     
     
    -class TestParams(HasMaxIter, HasInputCol):
    +class TestParams(HasMaxIter, HasInputCol, HasSeed):
         """
    -    A subclass of Params mixed with HasMaxIter and HasInputCol.
    +    A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
         """
    -
    -    def __init__(self):
    +    @keyword_only
    +    def __init__(self, seed=None):
             super(TestParams, self).__init__()
             self._setDefault(maxIter=10)
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, seed=None):
    +        """
    +        setParams(self, seed=None)
    +        Sets params for this test.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
    +
    +
    +class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
    +    """
    +    A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
    +    """
    +    @keyword_only
    +    def __init__(self, seed=None):
    +        super(OtherTestParams, self).__init__()
    +        self._setDefault(maxIter=10)
    +        kwargs = self.__init__._input_kwargs
    +        self.setParams(**kwargs)
    +
    +    @keyword_only
    +    def setParams(self, seed=None):
    +        """
    +        setParams(self, seed=None)
    +        Sets params for this test.
    +        """
    +        kwargs = self.setParams._input_kwargs
    +        return self._set(**kwargs)
     
     
     class ParamTests(PySparkTestCase):
    @@ -129,16 +161,18 @@ def test_param(self):
             maxIter = testParams.maxIter
             self.assertEqual(maxIter.name, "maxIter")
             self.assertEqual(maxIter.doc, "max number of iterations (>= 0)")
    -        self.assertTrue(maxIter.parent is testParams)
    +        self.assertTrue(maxIter.parent == testParams.uid)
     
         def test_params(self):
             testParams = TestParams()
             maxIter = testParams.maxIter
             inputCol = testParams.inputCol
    +        seed = testParams.seed
     
             params = testParams.params
    -        self.assertEqual(params, [inputCol, maxIter])
    +        self.assertEqual(params, [inputCol, maxIter, seed])
     
    +        self.assertTrue(testParams.hasParam(maxIter))
             self.assertTrue(testParams.hasDefault(maxIter))
             self.assertFalse(testParams.isSet(maxIter))
             self.assertTrue(testParams.isDefined(maxIter))
    @@ -147,16 +181,87 @@ def test_params(self):
             self.assertTrue(testParams.isSet(maxIter))
             self.assertEquals(testParams.getMaxIter(), 100)
     
    +        self.assertTrue(testParams.hasParam(inputCol))
             self.assertFalse(testParams.hasDefault(inputCol))
             self.assertFalse(testParams.isSet(inputCol))
             self.assertFalse(testParams.isDefined(inputCol))
             with self.assertRaises(KeyError):
                 testParams.getInputCol()
     
    +        # Since the default is normally random, set it to a known number for debug str
    +        testParams._setDefault(seed=41)
    +        testParams.setSeed(43)
    +
             self.assertEquals(
                 testParams.explainParams(),
                 "\n".join(["inputCol: input column name (undefined)",
    -                       "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))
    +                       "maxIter: max number of iterations (>= 0) (default: 10, current: 100)",
    +                       "seed: random seed (default: 41, current: 43)"]))
    +
    +    def test_hasseed(self):
    +        noSeedSpecd = TestParams()
    +        withSeedSpecd = TestParams(seed=42)
    +        other = OtherTestParams()
    +        # Check that we no longer use 42 as the magic number
    +        self.assertNotEqual(noSeedSpecd.getSeed(), 42)
    +        origSeed = noSeedSpecd.getSeed()
    +        # Check that we only compute the seed once
    +        self.assertEqual(noSeedSpecd.getSeed(), origSeed)
    +        # Check that a specified seed is honored
    +        self.assertEqual(withSeedSpecd.getSeed(), 42)
    +        # Check that a different class has a different seed
    +        self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
    +
    +
    +class FeatureTests(PySparkTestCase):
    +
    +    def test_binarizer(self):
    +        b0 = Binarizer()
    +        self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold])
    +        self.assertTrue(all([not b0.isSet(p) for p in b0.params]))
    +        self.assertTrue(b0.hasDefault(b0.threshold))
    +        self.assertEqual(b0.getThreshold(), 0.0)
    +        b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0)
    +        self.assertTrue(all([b0.isSet(p) for p in b0.params]))
    +        self.assertEqual(b0.getThreshold(), 1.0)
    +        self.assertEqual(b0.getInputCol(), "input")
    +        self.assertEqual(b0.getOutputCol(), "output")
    +
    +        b0c = b0.copy({b0.threshold: 2.0})
    +        self.assertEqual(b0c.uid, b0.uid)
    +        self.assertListEqual(b0c.params, b0.params)
    +        self.assertEqual(b0c.getThreshold(), 2.0)
    +
    +        b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output")
    +        self.assertNotEqual(b1.uid, b0.uid)
    +        self.assertEqual(b1.getThreshold(), 2.0)
    +        self.assertEqual(b1.getInputCol(), "input")
    +        self.assertEqual(b1.getOutputCol(), "output")
    +
    +    def test_idf(self):
    +        sqlContext = SQLContext(self.sc)
    +        dataset = sqlContext.createDataFrame([
    +            (DenseVector([1.0, 2.0]),),
    +            (DenseVector([0.0, 1.0]),),
    +            (DenseVector([3.0, 0.2]),)], ["tf"])
    +        idf0 = IDF(inputCol="tf")
    +        self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol])
    +        idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"})
    +        self.assertEqual(idf0m.uid, idf0.uid,
    +                         "Model should inherit the UID from its parent estimator.")
    +        output = idf0m.transform(dataset)
    +        self.assertIsNotNone(output.head().idf)
    +
    +    def test_ngram(self):
    +        sqlContext = SQLContext(self.sc)
    +        dataset = sqlContext.createDataFrame([
    +            ([["a", "b", "c", "d", "e"]])], ["input"])
    +        ngram0 = NGram(n=4, inputCol="input", outputCol="output")
    +        self.assertEqual(ngram0.getN(), 4)
    +        self.assertEqual(ngram0.getInputCol(), "input")
    +        self.assertEqual(ngram0.getOutputCol(), "output")
    +        transformedDF = ngram0.transform(dataset)
    +        self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"])
     
     
     if __name__ == "__main__":
    diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
    index 86f4dc7368be0..0bf988fd72f14 100644
    --- a/python/pyspark/ml/tuning.py
    +++ b/python/pyspark/ml/tuning.py
    @@ -91,20 +91,19 @@ class CrossValidator(Estimator):
         >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
         >>> from pyspark.mllib.linalg import Vectors
         >>> dataset = sqlContext.createDataFrame(
    -    ...     [(Vectors.dense([0.0, 1.0]), 0.0),
    -    ...      (Vectors.dense([1.0, 2.0]), 1.0),
    -    ...      (Vectors.dense([0.55, 3.0]), 0.0),
    -    ...      (Vectors.dense([0.45, 4.0]), 1.0),
    -    ...      (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
    +    ...     [(Vectors.dense([0.0]), 0.0),
    +    ...      (Vectors.dense([0.4]), 1.0),
    +    ...      (Vectors.dense([0.5]), 0.0),
    +    ...      (Vectors.dense([0.6]), 1.0),
    +    ...      (Vectors.dense([1.0]), 1.0)] * 10,
         ...     ["features", "label"])
         >>> lr = LogisticRegression()
    -    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
    +    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
         >>> evaluator = BinaryClassificationEvaluator()
         >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
    -    >>> # SPARK-7432: The following test is flaky.
    -    >>> # cvModel = cv.fit(dataset)
    -    >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
    -    >>> # cvModel.transform(dataset).collect() == expected.collect()
    +    >>> cvModel = cv.fit(dataset)
    +    >>> evaluator.evaluate(cvModel.transform(dataset))
    +    0.8333...
         """
     
         # a placeholder to make it appear in the generated doc
    @@ -155,7 +154,7 @@ def setEstimator(self, value):
             """
             Sets the value of :py:attr:`estimator`.
             """
    -        self.paramMap[self.estimator] = value
    +        self._paramMap[self.estimator] = value
             return self
     
         def getEstimator(self):
    @@ -168,7 +167,7 @@ def setEstimatorParamMaps(self, value):
             """
             Sets the value of :py:attr:`estimatorParamMaps`.
             """
    -        self.paramMap[self.estimatorParamMaps] = value
    +        self._paramMap[self.estimatorParamMaps] = value
             return self
     
         def getEstimatorParamMaps(self):
    @@ -181,7 +180,7 @@ def setEvaluator(self, value):
             """
             Sets the value of :py:attr:`evaluator`.
             """
    -        self.paramMap[self.evaluator] = value
    +        self._paramMap[self.evaluator] = value
             return self
     
         def getEvaluator(self):
    @@ -194,7 +193,7 @@ def setNumFolds(self, value):
             """
             Sets the value of :py:attr:`numFolds`.
             """
    -        self.paramMap[self.numFolds] = value
    +        self._paramMap[self.numFolds] = value
             return self
     
         def getNumFolds(self):
    @@ -203,13 +202,12 @@ def getNumFolds(self):
             """
             return self.getOrDefault(self.numFolds)
     
    -    def fit(self, dataset, params={}):
    -        paramMap = self.extractParamMap(params)
    -        est = paramMap[self.estimator]
    -        epm = paramMap[self.estimatorParamMaps]
    +    def _fit(self, dataset):
    +        est = self.getOrDefault(self.estimator)
    +        epm = self.getOrDefault(self.estimatorParamMaps)
             numModels = len(epm)
    -        eva = paramMap[self.evaluator]
    -        nFolds = paramMap[self.numFolds]
    +        eva = self.getOrDefault(self.evaluator)
    +        nFolds = self.getOrDefault(self.numFolds)
             h = 1.0 / nFolds
             randCol = self.uid + "_rand"
             df = dataset.select("*", rand(0).alias(randCol))
    @@ -229,6 +227,15 @@ def fit(self, dataset, params={}):
             bestModel = est.fit(dataset, epm[bestIndex])
             return CrossValidatorModel(bestModel)
     
    +    def copy(self, extra={}):
    +        newCV = Params.copy(self, extra)
    +        if self.isSet(self.estimator):
    +            newCV.setEstimator(self.getEstimator().copy(extra))
    +        # estimatorParamMaps remain the same
    +        if self.isSet(self.evaluator):
    +            newCV.setEvaluator(self.getEvaluator().copy(extra))
    +        return newCV
    +
     
     class CrossValidatorModel(Model):
         """
    @@ -240,8 +247,19 @@ def __init__(self, bestModel):
             #: best model from cross validation
             self.bestModel = bestModel
     
    -    def transform(self, dataset, params={}):
    -        return self.bestModel.transform(dataset, params)
    +    def _transform(self, dataset):
    +        return self.bestModel.transform(dataset)
    +
    +    def copy(self, extra={}):
    +        """
    +        Creates a copy of this instance with a randomly generated uid
    +        and some extra params. This copies the underlying bestModel,
    +        creates a deep copy of the embedded paramMap, and
    +        copies the embedded and extra parameters over.
    +        :param extra: Extra parameters to copy to the new instance
    +        :return: Copy of this instance
    +        """
    +        return CrossValidatorModel(self.bestModel.copy(extra))
     
     
     if __name__ == "__main__":
    diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
    index d3cb100a9efa5..cee9d67b05325 100644
    --- a/python/pyspark/ml/util.py
    +++ b/python/pyspark/ml/util.py
    @@ -39,9 +39,16 @@ class Identifiable(object):
         """
     
         def __init__(self):
    -        #: A unique id for the object. The default implementation
    -        #: concatenates the class name, "_", and 8 random hex chars.
    -        self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8]
    +        #: A unique id for the object.
    +        self.uid = self._randomUID()
     
         def __repr__(self):
             return self.uid
    +
    +    @classmethod
    +    def _randomUID(cls):
    +        """
    +        Generate a unique id for the object. The default implementation
    +        concatenates the class name, "_", and 12 random hex chars.
    +        """
    +        return cls.__name__ + "_" + uuid.uuid4().hex[-12:]
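
The uid format matters elsewhere in this patch: `Param.parent` is now compared against `testParams.uid`, and `JavaModel` copies its uid from the Java peer. A tiny sketch of what `_randomUID` produces, assuming the class-name-plus-12-hex-chars format described in its docstring:

```
import uuid


def random_uid(cls_name):
    """Illustrative stand-in for Identifiable._randomUID: the class name,
    an underscore, and 12 random hex characters."""
    return cls_name + "_" + uuid.uuid4().hex[-12:]


print(random_uid("LinearRegression"))  # e.g. LinearRegression_3c1f9b4a7d2e
```
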
    diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
    index f5ac2a398642a..253705bde913e 100644
    --- a/python/pyspark/ml/wrapper.py
    +++ b/python/pyspark/ml/wrapper.py
    @@ -20,8 +20,8 @@
     from pyspark import SparkContext
     from pyspark.sql import DataFrame
     from pyspark.ml.param import Params
    -from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
    -from pyspark.mllib.common import inherit_doc
    +from pyspark.ml.pipeline import Estimator, Transformer, Model
    +from pyspark.mllib.common import inherit_doc, _java2py, _py2java
     
     
     def _jvm():
    @@ -45,46 +45,61 @@ class JavaWrapper(Params):
     
         __metaclass__ = ABCMeta
     
    -    #: Fully-qualified class name of the wrapped Java component.
    -    _java_class = None
    +    #: The wrapped Java companion object. Subclasses should initialize
    +    #: it properly. The param values in the Java object should be
    +    #: synced with the Python wrapper in fit/transform/evaluate/copy.
    +    _java_obj = None
     
    -    def _java_obj(self):
    +    @staticmethod
    +    def _new_java_obj(java_class, *args):
             """
    -        Returns or creates a Java object.
    +        Construct a new Java object.
             """
    +        sc = SparkContext._active_spark_context
             java_obj = _jvm()
    -        for name in self._java_class.split("."):
    +        for name in java_class.split("."):
                 java_obj = getattr(java_obj, name)
    -        return java_obj()
    +        java_args = [_py2java(sc, arg) for arg in args]
    +        return java_obj(*java_args)
     
    -    def _transfer_params_to_java(self, params, java_obj):
    +    def _make_java_param_pair(self, param, value):
             """
    -        Transforms the embedded params and additional params to the
    -        input Java object.
    -        :param params: additional params (overwriting embedded values)
    -        :param java_obj: Java object to receive the params
    +        Makes a Java param pair.
    +        """
    +        sc = SparkContext._active_spark_context
    +        param = self._resolveParam(param)
    +        java_param = self._java_obj.getParam(param.name)
    +        java_value = _py2java(sc, value)
    +        return java_param.w(java_value)
    +
    +    def _transfer_params_to_java(self):
             """
    -        paramMap = self.extractParamMap(params)
    +        Transfers the embedded params to the companion Java object.
    +        """
    +        paramMap = self.extractParamMap()
             for param in self.params:
                 if param in paramMap:
    -                value = paramMap[param]
    -                java_param = java_obj.getParam(param.name)
    -                java_obj.set(java_param.w(value))
    +                pair = self._make_java_param_pair(param, paramMap[param])
    +                self._java_obj.set(pair)
    +
    +    def _transfer_params_from_java(self):
    +        """
    +        Transfers the embedded params from the companion Java object.
    +        """
    +        sc = SparkContext._active_spark_context
    +        for param in self.params:
    +            if self._java_obj.hasParam(param.name):
    +                java_param = self._java_obj.getParam(param.name)
    +                value = _java2py(sc, self._java_obj.getOrDefault(java_param))
    +                self._paramMap[param] = value
     
    -    def _empty_java_param_map(self):
    +    @staticmethod
    +    def _empty_java_param_map():
             """
             Returns an empty Java ParamMap reference.
             """
             return _jvm().org.apache.spark.ml.param.ParamMap()
     
    -    def _create_java_param_map(self, params, java_obj):
    -        paramMap = self._empty_java_param_map()
    -        for param, value in params.items():
    -            if param.parent is self:
    -                java_param = java_obj.getParam(param.name)
    -                paramMap.put(java_param.w(value))
    -        return paramMap
    -
     
     @inherit_doc
     class JavaEstimator(Estimator, JavaWrapper):
    @@ -99,9 +114,9 @@ def _create_model(self, java_model):
             """
             Creates a model from the input Java model reference.
             """
    -        return JavaModel(java_model)
    +        raise NotImplementedError()
     
    -    def _fit_java(self, dataset, params={}):
    +    def _fit_java(self, dataset):
             """
             Fits a Java model to the input dataset.
             :param dataset: input dataset, which is an instance of
    @@ -109,12 +124,11 @@ def _fit_java(self, dataset, params={}):
             :param params: additional params (overwriting embedded values)
             :return: fitted Java model
             """
    -        java_obj = self._java_obj()
    -        self._transfer_params_to_java(params, java_obj)
    -        return java_obj.fit(dataset._jdf, self._empty_java_param_map())
    +        self._transfer_params_to_java()
    +        return self._java_obj.fit(dataset._jdf)
     
    -    def fit(self, dataset, params={}):
    -        java_model = self._fit_java(dataset, params)
    +    def _fit(self, dataset):
    +        java_model = self._fit_java(dataset)
             return self._create_model(java_model)
     
     
    @@ -127,39 +141,49 @@ class JavaTransformer(Transformer, JavaWrapper):
     
         __metaclass__ = ABCMeta
     
    -    def transform(self, dataset, params={}):
    -        java_obj = self._java_obj()
    -        self._transfer_params_to_java(params, java_obj)
    -        return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
    +    def _transform(self, dataset):
    +        self._transfer_params_to_java()
    +        return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
     
     
     @inherit_doc
     class JavaModel(Model, JavaTransformer):
         """
         Base class for :py:class:`Model`s that wrap Java/Scala
    -    implementations.
    +    implementations. Subclasses should inherit this class before
    +    param mix-ins, because this sets the UID from the Java model.
         """
     
         __metaclass__ = ABCMeta
     
         def __init__(self, java_model):
    -        super(JavaTransformer, self).__init__()
    -        self._java_model = java_model
    -
    -    def _java_obj(self):
    -        return self._java_model
    -
    -
    -@inherit_doc
    -class JavaEvaluator(Evaluator, JavaWrapper):
    -    """
    -    Base class for :py:class:`Evaluator`s that wrap Java/Scala
    -    implementations.
    -    """
    -
    -    __metaclass__ = ABCMeta
    +        """
    +        Initialize this instance with a Java model object.
    +        Subclasses should call this constructor, initialize params,
    +        and then call _transfer_params_from_java.
    +        """
    +        super(JavaModel, self).__init__()
    +        self._java_obj = java_model
    +        self.uid = java_model.uid()
     
    -    def evaluate(self, dataset, params={}):
    -        java_obj = self._java_obj()
    -        self._transfer_params_to_java(params, java_obj)
    -        return java_obj.evaluate(dataset._jdf, self._empty_java_param_map())
    +    def copy(self, extra=None):
    +        """
    +        Creates a copy of this instance with the same uid and some
    +        extra params. This implementation first calls Params.copy and
    +        then makes a copy of the companion Java model with extra params.
    +        So both the Python wrapper and the Java model get copied.
    +        :param extra: Extra parameters to copy to the new instance
    +        :return: Copy of this instance
    +        """
    +        if extra is None:
    +            extra = dict()
    +        that = super(JavaModel, self).copy(extra)
    +        that._java_obj = self._java_obj.copy(self._empty_java_param_map())
    +        that._transfer_params_to_java()
    +        return that
    +
    +    def _call_java(self, name, *args):
    +        m = getattr(self._java_obj, name)
    +        sc = SparkContext._active_spark_context
    +        java_args = [_py2java(sc, arg) for arg in args]
    +        return _java2py(sc, m(*java_args))
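
The wrapper.py rewrite centralizes all Python-to-JVM traffic in three helpers: `_transfer_params_to_java` pushes the extracted param map into the companion object, `_transfer_params_from_java` pulls values back into `_paramMap`, and `_call_java` converts arguments with `_py2java`, invokes the named method, and converts the result with `_java2py`. A toy illustration of the `_call_java` round-trip using a fake companion object instead of a real py4j proxy (everything here is a stand-in, not the actual pyspark machinery):

```
class FakeJavaModel(object):
    """Stands in for a py4j proxy of a fitted Scala model."""
    def weights(self):
        return [1.0]

    def depth(self):
        return 1


def call_java(java_obj, name, *args):
    """Schematic version of JavaWrapper._call_java: look up the method by
    name on the companion object and call it. The real helper additionally
    runs each argument through _py2java and the result through _java2py."""
    method = getattr(java_obj, name)
    return method(*args)


model = FakeJavaModel()
print(call_java(model, "weights"))  # [1.0]
print(call_java(model, "depth"))    # 1
```
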
    diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
    index 07507b2ad0d05..acba3a717d21a 100644
    --- a/python/pyspark/mllib/__init__.py
    +++ b/python/pyspark/mllib/__init__.py
    @@ -23,16 +23,10 @@
     # MLlib currently needs NumPy 1.4+, so complain if lower
     
     import numpy
    -if numpy.version.version < '1.4':
    +
    +ver = [int(x) for x in numpy.version.version.split('.')[:2]]
    +if ver < [1, 4]:
         raise Exception("MLlib requires NumPy 1.4+")
     
     __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
                'recommendation', 'regression', 'stat', 'tree', 'util']
    -
    -import sys
    -from . import rand as random
    -modname = __name__ + '.random'
    -random.__name__ = modname
    -random.RandomRDDs.__module__ = modname
    -sys.modules[modname] = random
    -del modname, sys
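
The NumPy check above switches from a lexicographic string comparison to comparing `[major, minor]` integer lists; the string form misorders double-digit components (for example, '1.10' sorts before '1.4'). A quick demonstration of the difference:

```
def too_old_string(version):
    # Old check: plain string comparison, which sorts '1.10' before '1.4'.
    return version < '1.4'


def too_old_list(version):
    # New check: compare [major, minor] as integers.
    return [int(x) for x in version.split('.')[:2]] < [1, 4]


for v in ('1.3.1', '1.4.0', '1.10.2'):
    print("%s  string-check too old: %s  list-check too old: %s"
          % (v, too_old_string(v), too_old_list(v)))
# Only '1.10.2' differs: the string comparison wrongly flags it as too old.
```
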
    diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
    index a70c664a71fdb..8f27c446a66e8 100644
    --- a/python/pyspark/mllib/classification.py
    +++ b/python/pyspark/mllib/classification.py
    @@ -21,20 +21,24 @@
     from numpy import array
     
     from pyspark import RDD
    +from pyspark.streaming import DStream
     from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
     from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
    -from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
    +from pyspark.mllib.regression import (
    +    LabeledPoint, LinearModel, _regression_train_wrapper,
    +    StreamingLinearAlgorithm)
     from pyspark.mllib.util import Saveable, Loader, inherit_doc
     
     
     __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
    -           'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
    +           'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes',
    +           'StreamingLogisticRegressionWithSGD']
     
     
     class LinearClassificationModel(LinearModel):
         """
    -    A private abstract class representing a multiclass classification model.
    -    The categories are represented by int values: 0, 1, 2, etc.
    +    A private abstract class representing a multiclass classification
    +    model. The categories are represented by int values: 0, 1, 2, etc.
         """
         def __init__(self, weights, intercept):
             super(LinearClassificationModel, self).__init__(weights, intercept)
    @@ -44,10 +48,11 @@ def setThreshold(self, value):
             """
             .. note:: Experimental
     
    -        Sets the threshold that separates positive predictions from negative
    -        predictions. An example with prediction score greater than or equal
    -        to this threshold is identified as an positive, and negative otherwise.
    -        It is used for binary classification only.
    +        Sets the threshold that separates positive predictions from
    +        negative predictions. An example with prediction score greater
    +        than or equal to this threshold is identified as positive,
    +        and negative otherwise. It is used for binary classification
    +        only.
             """
             self._threshold = value
     
    @@ -56,8 +61,9 @@ def threshold(self):
             """
             .. note:: Experimental
     
    -        Returns the threshold (if any) used for converting raw prediction scores
    -        into 0/1 predictions. It is used for binary classification only.
    +        Returns the threshold (if any) used for converting raw
    +        prediction scores into 0/1 predictions. It is used for
    +        binary classification only.
             """
             return self._threshold
     
    @@ -65,22 +71,35 @@ def clearThreshold(self):
             """
             .. note:: Experimental
     
    -        Clears the threshold so that `predict` will output raw prediction scores.
    -        It is used for binary classification only.
    +        Clears the threshold so that `predict` will output raw
    +        prediction scores. It is used for binary classification only.
             """
             self._threshold = None
     
         def predict(self, test):
             """
    -        Predict values for a single data point or an RDD of points using
    -        the model trained.
    +        Predict values for a single data point or an RDD of points
    +        using the model trained.
             """
             raise NotImplementedError
     
     
     class LogisticRegressionModel(LinearClassificationModel):
     
    -    """A linear binary classification model derived from logistic regression.
    +    """
    +    Classification model trained using Multinomial/Binary Logistic
    +    Regression.
    +
    +    :param weights: Weights computed for every feature.
    +    :param intercept: Intercept computed for this model. (Only used
    +            in Binary Logistic Regression. In Multinomial Logistic
    +            Regression, the intercepts will not be a single value,
    +            so the intercepts will be part of the weights.)
    +    :param numFeatures: the dimension of the features.
    +    :param numClasses: the number of possible outcomes for a k-class
    +            classification problem in Multinomial Logistic Regression.
    +            By default, it is binary logistic regression so numClasses
    +            will be set to 2.
     
         >>> data = [
         ...     LabeledPoint(0.0, [0.0, 1.0]),
    @@ -120,8 +139,9 @@ class LogisticRegressionModel(LinearClassificationModel):
         1
         >>> sameModel.predict(SparseVector(2, {0: 1.0}))
         0
    +    >>> from shutil import rmtree
         >>> try:
    -    ...    os.removedirs(path)
    +    ...    rmtree(path)
         ... except:
         ...    pass
         >>> multi_class_data = [
    @@ -161,8 +181,8 @@ def numClasses(self):
     
         def predict(self, x):
             """
    -        Predict values for a single data point or an RDD of points using
    -        the model trained.
    +        Predict values for a single data point or an RDD of points
    +        using the model trained.
             """
             if isinstance(x, RDD):
                 return x.map(lambda v: self.predict(v))
    @@ -225,16 +245,19 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
             """
             Train a logistic regression model on the given data.
     
    -        :param data:              The training data, an RDD of LabeledPoint.
    -        :param iterations:        The number of iterations (default: 100).
    +        :param data:              The training data, an RDD of
    +                                  LabeledPoint.
    +        :param iterations:        The number of iterations
    +                                  (default: 100).
             :param step:              The step parameter used in SGD
                                       (default: 1.0).
    -        :param miniBatchFraction: Fraction of data to be used for each SGD
    -                                  iteration.
    +        :param miniBatchFraction: Fraction of data to be used for each
    +                                  SGD iteration (default: 1.0).
             :param initialWeights:    The initial weights (default: None).
    -        :param regParam:          The regularizer parameter (default: 0.01).
    -        :param regType:           The type of regularizer used for training
    -                                  our model.
    +        :param regParam:          The regularizer parameter
    +                                  (default: 0.01).
    +        :param regType:           The type of regularizer used for
    +                                  training our model.
     
                                       :Allowed values:
                                          - "l1" for using L1 regularization
    @@ -243,13 +266,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
     
                                          (default: "l2")
     
    -        :param intercept:         Boolean parameter which indicates the use
    -                                  or not of the augmented representation for
    -                                  training data (i.e. whether bias features
    -                                  are activated or not).
    -        :param validateData:      Boolean parameter which indicates if the
    -                                  algorithm should validate data before training.
    -                                  (default: True)
    +        :param intercept:         Boolean parameter which indicates whether
    +                                  to use an augmented representation for
    +                                  training data (i.e. whether bias
    +                                  features are activated or not,
    +                                  default: False).
    +        :param validateData:      Boolean parameter which indicates if
    +                                  the algorithm should validate data
    +                                  before training. (default: True)
             """
             def train(rdd, i):
                 return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
    @@ -267,12 +291,15 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
             """
             Train a logistic regression model on the given data.
     
    -        :param data:           The training data, an RDD of LabeledPoint.
    -        :param iterations:     The number of iterations (default: 100).
    +        :param data:           The training data, an RDD of
    +                               LabeledPoint.
    +        :param iterations:     The number of iterations
    +                               (default: 100).
             :param initialWeights: The initial weights (default: None).
    -        :param regParam:       The regularizer parameter (default: 0.01).
    -        :param regType:        The type of regularizer used for training
    -                               our model.
    +        :param regParam:       The regularizer parameter
    +                               (default: 0.01).
    +        :param regType:        The type of regularizer used for
    +                               training our model.
     
                                    :Allowed values:
                                      - "l1" for using L1 regularization
    @@ -281,19 +308,21 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
     
                                      (default: "l2")
     
    -        :param intercept:      Boolean parameter which indicates the use
    -                               or not of the augmented representation for
    -                               training data (i.e. whether bias features
    -                               are activated or not).
    -        :param corrections:    The number of corrections used in the LBFGS
    -                               update (default: 10).
    -        :param tolerance:      The convergence tolerance of iterations for
    -                               L-BFGS (default: 1e-4).
    +        :param intercept:      Boolean parameter which indicates whether
    +                               to use an augmented representation for
    +                               training data (i.e. whether bias
    +                               features are activated or not,
    +                               default: False).
    +        :param corrections:    The number of corrections used in the
    +                               LBFGS update (default: 10).
    +        :param tolerance:      The convergence tolerance of iterations
    +                               for L-BFGS (default: 1e-4).
             :param validateData:   Boolean parameter which indicates if the
    -                               algorithm should validate data before training.
    -                               (default: True)
    -        :param numClasses:     The number of classes (i.e., outcomes) a label can take
    -                               in Multinomial Logistic Regression (default: 2).
    +                               algorithm should validate data before
    +                               training. (default: True)
    +        :param numClasses:     The number of classes (i.e., outcomes) a
    +                               label can take in Multinomial Logistic
    +                               Regression (default: 2).
     
             >>> data = [
             ...     LabeledPoint(0.0, [0.0, 1.0]),
    @@ -323,7 +352,11 @@ def train(rdd, i):
     
     class SVMModel(LinearClassificationModel):
     
    -    """A support vector machine.
    +    """
    +    Model for Support Vector Machines (SVMs).
    +
    +    :param weights: Weights computed for every feature.
    +    :param intercept: Intercept computed for this model.
     
         >>> data = [
         ...     LabeledPoint(0.0, [0.0]),
    @@ -359,8 +392,9 @@ class SVMModel(LinearClassificationModel):
         1
         >>> sameModel.predict(SparseVector(2, {0: -1.0}))
         0
    +    >>> from shutil import rmtree
         >>> try:
    -    ...    os.removedirs(path)
    +    ...    rmtree(path)
         ... except:
         ...    pass
         """
    @@ -370,8 +404,8 @@ def __init__(self, weights, intercept):
     
         def predict(self, x):
             """
    -        Predict values for a single data point or an RDD of points using
    -        the model trained.
    +        Predict values for a single data point or an RDD of points
    +        using the model trained.
             """
             if isinstance(x, RDD):
                 return x.map(lambda v: self.predict(v))
    @@ -409,16 +443,19 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
             """
             Train a support vector machine on the given data.
     
    -        :param data:              The training data, an RDD of LabeledPoint.
    -        :param iterations:        The number of iterations (default: 100).
    +        :param data:              The training data, an RDD of
    +                                  LabeledPoint.
    +        :param iterations:        The number of iterations
    +                                  (default: 100).
             :param step:              The step parameter used in SGD
                                       (default: 1.0).
    -        :param regParam:          The regularizer parameter (default: 0.01).
    -        :param miniBatchFraction: Fraction of data to be used for each SGD
    -                                  iteration.
    +        :param regParam:          The regularizer parameter
    +                                  (default: 0.01).
    +        :param miniBatchFraction: Fraction of data to be used for each
    +                                  SGD iteration (default: 1.0).
             :param initialWeights:    The initial weights (default: None).
    -        :param regType:           The type of regularizer used for training
    -                                  our model.
    +        :param regType:           The type of regularizer used for
    +                                  training our model.
     
                                       :Allowed values:
                                          - "l1" for using L1 regularization
    @@ -427,13 +464,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
     
                                          (default: "l2")
     
    -        :param intercept:         Boolean parameter which indicates the use
    -                                  or not of the augmented representation for
    -                                  training data (i.e. whether bias features
    -                                  are activated or not).
    -        :param validateData:      Boolean parameter which indicates if the
    -                                  algorithm should validate data before training.
    -                                  (default: True)
    +        :param intercept:         Boolean parameter which indicates the
    +                                  use or not of the augmented representation
    +                                  for training data (i.e. whether bias
    +                                  features are activated or not,
    +                                  default: False).
    +        :param validateData:      Boolean parameter which indicates if
    +                                  the algorithm should validate data
    +                                  before training. (default: True)
             """
             def train(rdd, i):
                 return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
    @@ -449,9 +487,11 @@ class NaiveBayesModel(Saveable, Loader):
         """
         Model for Naive Bayes classifiers.
     
    -    Contains two parameters:
    -    - pi: vector of logs of class priors (dimension C)
    -    - theta: matrix of logs of class conditional probabilities (CxD)
    +    :param labels: list of labels.
    +    :param pi: log of class priors, whose dimension is C,
    +            number of labels.
    +    :param theta: log of class conditional probabilities, whose
    +            dimension is C-by-D, where D is number of features.
     
         >>> data = [
         ...     LabeledPoint(0.0, [0.0, 0.0]),
    @@ -481,8 +521,9 @@ class NaiveBayesModel(Saveable, Loader):
         >>> sameModel = NaiveBayesModel.load(sc, path)
         >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
         True
    +    >>> from shutil import rmtree
         >>> try:
    -    ...     os.removedirs(path)
    +    ...     rmtree(path)
         ... except OSError:
         ...     pass
         """
    @@ -493,7 +534,10 @@ def __init__(self, labels, pi, theta):
             self.theta = theta
     
         def predict(self, x):
    -        """Return the most likely class for a data vector or an RDD of vectors"""
    +        """
    +        Return the most likely class for a data vector
    +        or an RDD of vectors
    +        """
             if isinstance(x, RDD):
                 return x.map(lambda v: self.predict(v))
             x = _convert_to_vector(x)
    @@ -523,24 +567,76 @@ class NaiveBayes(object):
         @classmethod
         def train(cls, data, lambda_=1.0):
             """
    -        Train a Naive Bayes model given an RDD of (label, features) vectors.
    +        Train a Naive Bayes model given an RDD of (label, features)
    +        vectors.
     
    -        This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which can
    -        handle all kinds of discrete data.  For example, by converting
    -        documents into TF-IDF vectors, it can be used for document
    -        classification.  By making every vector a 0-1 vector, it can also be
    -        used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
    +        This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which
    +        can handle all kinds of discrete data.  For example, by
    +        converting documents into TF-IDF vectors, it can be used for
    +        document classification. By making every vector a 0-1 vector,
    +        it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
    +        The input feature values must be nonnegative.
     
             :param data: RDD of LabeledPoint.
    -        :param lambda_: The smoothing parameter
    +        :param lambda_: The smoothing parameter (default: 1.0).
             """
             first = data.first()
             if not isinstance(first, LabeledPoint):
                 raise ValueError("`data` should be an RDD of LabeledPoint")
    -        labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_)
    +        labels, pi, theta = callMLlibFunc("trainNaiveBayesModel", data, lambda_)
             return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
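As a side note on the Bernoulli-style use described in the docstring above, a minimal illustrative sketch (not part of this patch; it assumes a live `sc`, like the doctests) would binarize the term-frequency vectors before calling `NaiveBayes.train`:

```python
from pyspark.mllib.classification import NaiveBayes
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LabeledPoint

def binarize(lp):
    # Map every nonzero feature value to 1.0; zeros stay 0.0.
    values = [1.0 if v != 0.0 else 0.0 for v in lp.features.toArray()]
    return LabeledPoint(lp.label, Vectors.dense(values))

tf = sc.parallelize([LabeledPoint(0.0, [2.0, 0.0]), LabeledPoint(1.0, [0.0, 3.0])])
model = NaiveBayes.train(tf.map(binarize), lambda_=1.0)
model.predict(Vectors.dense([1.0, 0.0]))   # expected label: 0.0
```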
     
     
    +@inherit_doc
    +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm):
    +    """
    +    Run LogisticRegression with SGD on a batch of data.
    +
    +    The weights obtained at the end of training a stream are used as initial
    +    weights for the next batch.
    +
    +    :param stepSize: Step size for each iteration of gradient descent.
    +    :param numIterations: Number of iterations run for each batch of data.
    +    :param miniBatchFraction: Fraction of data on which SGD is run for each
    +                              iteration.
    +    :param regParam: L2 Regularization parameter.
    +    """
    +    def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01):
    +        self.stepSize = stepSize
    +        self.numIterations = numIterations
    +        self.regParam = regParam
    +        self.miniBatchFraction = miniBatchFraction
    +        self._model = None
    +        super(StreamingLogisticRegressionWithSGD, self).__init__(
    +            model=self._model)
    +
    +    def setInitialWeights(self, initialWeights):
    +        """
    +        Set the initial value of weights.
    +
    +        This must be set before running trainOn and predictOn.
    +        """
    +        initialWeights = _convert_to_vector(initialWeights)
    +
    +        # LogisticRegressionWithSGD does only binary classification.
    +        self._model = LogisticRegressionModel(
    +            initialWeights, 0, initialWeights.size, 2)
    +        return self
    +
    +    def trainOn(self, dstream):
    +        """Train the model on the incoming dstream."""
    +        self._validate(dstream)
    +
    +        def update(rdd):
    +            # LogisticRegressionWithSGD.train raises an error for an empty RDD.
    +            if not rdd.isEmpty():
    +                self._model = LogisticRegressionWithSGD.train(
    +                    rdd, self.numIterations, self.stepSize,
    +                    self.miniBatchFraction, self._model.weights)
    +
    +        dstream.foreachRDD(update)
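For reference, a minimal usage sketch of the new streaming classifier (illustrative only; it assumes an existing StreamingContext `ssc` alongside `sc`, and `predictOn` comes from the StreamingLinearAlgorithm base class):

```python
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

# One training batch of labeled points, wrapped in a queueStream for the demo.
batch = sc.parallelize([LabeledPoint(1.0, [1.0, 1.5]), LabeledPoint(0.0, [-1.0, -0.5])])
trainingStream = ssc.queueStream([batch])

model = StreamingLogisticRegressionWithSGD(stepSize=0.1, numIterations=25)
model.setInitialWeights([0.0, 0.0])        # required before trainOn/predictOn
model.trainOn(trainingStream)              # weights carry over to the next batch
model.predictOn(trainingStream.map(lambda lp: lp.features)).pprint()
# ssc.start() then begins processing the queued batch.
```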
    +
    +
     def _test():
         import doctest
         from pyspark import SparkContext
    diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
    index 04e67158514f5..ed4d78a2c6788 100644
    --- a/python/pyspark/mllib/clustering.py
    +++ b/python/pyspark/mllib/clustering.py
    @@ -21,16 +21,23 @@
     if sys.version > '3':
         xrange = range
     
    -from numpy import array
    +from math import exp, log
    +
    +from numpy import array, random, tile
    +
    +from collections import namedtuple
     
    -from pyspark import RDD
     from pyspark import SparkContext
    -from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
    -from pyspark.mllib.linalg import SparseVector, _convert_to_vector
    +from pyspark.rdd import RDD, ignore_unicode_prefix
    +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py
    +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
     from pyspark.mllib.stat.distribution import MultivariateGaussian
    -from pyspark.mllib.util import Saveable, Loader, inherit_doc
    +from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable
    +from pyspark.streaming import DStream
     
    -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
    +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
    +           'PowerIterationClusteringModel', 'PowerIterationClustering',
    +           'StreamingKMeans', 'StreamingKMeansModel']
     
     
     @inherit_doc
    @@ -75,8 +82,9 @@ class KMeansModel(Saveable, Loader):
         >>> sameModel = KMeansModel.load(sc, path)
         >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
         True
    +    >>> from shutil import rmtree
         >>> try:
    -    ...     os.removedirs(path)
    +    ...     rmtree(path)
         ... except OSError:
         ...     pass
         """
    @@ -98,6 +106,9 @@ def predict(self, x):
             """Find the cluster to which x belongs in this model."""
             best = 0
             best_distance = float("inf")
    +        if isinstance(x, RDD):
    +            return x.map(self.predict)
    +
             x = _convert_to_vector(x)
             for i in xrange(len(self.centers)):
                 distance = x.squared_distance(self.centers[i])
    @@ -142,6 +153,7 @@ class GaussianMixtureModel(object):
     
         """A clustering model derived from the Gaussian Mixture Model method.
     
    +    >>> from pyspark.mllib.linalg import Vectors, DenseMatrix
         >>> clusterdata_1 =  sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
         ...                                         0.9,0.8,0.75,0.935,
         ...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
    @@ -154,11 +166,12 @@ class GaussianMixtureModel(object):
         True
         >>> labels[4]==labels[5]
         True
    -    >>> clusterdata_2 =  sc.parallelize(array([-5.1971, -2.5359, -3.8220,
    -    ...                                        -5.2211, -5.0602,  4.7118,
    -    ...                                         6.8989, 3.4592,  4.6322,
    -    ...                                         5.7048,  4.6567, 5.5026,
    -    ...                                         4.5605,  5.2043,  6.2734]).reshape(5, 3))
    +    >>> data =  array([-5.1971, -2.5359, -3.8220,
    +    ...                -5.2211, -5.0602,  4.7118,
    +    ...                 6.8989, 3.4592,  4.6322,
    +    ...                 5.7048,  4.6567, 5.5026,
    +    ...                 4.5605,  5.2043,  6.2734])
    +    >>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
         >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
         ...                               maxIterations=150, seed=10)
         >>> labels = model.predict(clusterdata_2).collect()
    @@ -166,12 +179,38 @@ class GaussianMixtureModel(object):
         True
         >>> labels[3]==labels[4]
         True
    +    >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
    +    >>> im = GaussianMixtureModel([0.5, 0.5],
    +    ...      [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])),
    +    ...      MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))])
    +    >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
         """
     
         def __init__(self, weights, gaussians):
    -        self.weights = weights
    -        self.gaussians = gaussians
    -        self.k = len(self.weights)
    +        self._weights = weights
    +        self._gaussians = gaussians
    +        self._k = len(self._weights)
    +
    +    @property
    +    def weights(self):
    +        """
    +        Weights for each Gaussian distribution in the mixture, where weights[i] is
    +        the weight for Gaussian i, and weights.sum == 1.
    +        """
    +        return self._weights
    +
    +    @property
    +    def gaussians(self):
    +        """
    +        Array of MultivariateGaussian where gaussians[i] represents
    +        the Multivariate Gaussian (Normal) Distribution for Gaussian i.
    +        """
    +        return self._gaussians
    +
    +    @property
    +    def k(self):
    +        """Number of gaussians in mixture."""
    +        return self._k
     
         def predict(self, x):
             """
    @@ -184,6 +223,9 @@ def predict(self, x):
             if isinstance(x, RDD):
                 cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
                 return cluster_labels
    +        else:
    +            raise TypeError("x should be represented by an RDD, "
    +                            "but got %s." % type(x))
     
         def predictSoft(self, x):
             """
    @@ -193,10 +235,13 @@ def predictSoft(self, x):
             :return:     membership_matrix. RDD of array of double values.
             """
             if isinstance(x, RDD):
    -            means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
    +            means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians])
                 membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
    -                                              _convert_to_vector(self.weights), means, sigmas)
    +                                              _convert_to_vector(self._weights), means, sigmas)
                 return membership_matrix.map(lambda x: pyarray.array('d', x))
    +        else:
    +            raise TypeError("x should be represented by an RDD, "
    +                            "but got %s." % type(x))
     
     
     class GaussianMixture(object):
    @@ -208,20 +253,320 @@ class GaussianMixture(object):
         :param convergenceTol:  Threshold value to check the convergence criteria. Defaults to 1e-3
         :param maxIterations:   Number of iterations. Default to 100
         :param seed:            Random Seed
    +    :param initialModel:    GaussianMixtureModel for initializing learning
         """
         @classmethod
    -    def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None):
    +    def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None):
             """Train a Gaussian Mixture clustering model."""
    -        weight, mu, sigma = callMLlibFunc("trainGaussianMixture",
    -                                          rdd.map(_convert_to_vector), k,
    -                                          convergenceTol, maxIterations, seed)
    +        initialModelWeights = None
    +        initialModelMu = None
    +        initialModelSigma = None
    +        if initialModel is not None:
    +            if initialModel.k != k:
    +                raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
    +                                % (initialModel.k, k))
    +            initialModelWeights = initialModel.weights
    +            initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
    +            initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
    +        weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
    +                                          k, convergenceTol, maxIterations, seed,
    +                                          initialModelWeights, initialModelMu, initialModelSigma)
             mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
             return GaussianMixtureModel(weight, mvg_obj)
     
     
    +class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
    +
    +    """
    +    .. note:: Experimental
    +
    +    Model produced by [[PowerIterationClustering]].
    +
    +    >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0),
    +    ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0),
    +    ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0),
    +    ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)]
    +    >>> rdd = sc.parallelize(data, 2)
    +    >>> model = PowerIterationClustering.train(rdd, 2, 100)
    +    >>> model.k
    +    2
    +    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
    +    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
    +    True
    +    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
    +    True
    +    >>> import os, tempfile
    +    >>> path = tempfile.mkdtemp()
    +    >>> model.save(sc, path)
    +    >>> sameModel = PowerIterationClusteringModel.load(sc, path)
    +    >>> sameModel.k
    +    2
    +    >>> result = sorted(model.assignments().collect(), key=lambda x: x.id)
    +    >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster
    +    True
    +    >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster
    +    True
    +    >>> from shutil import rmtree
    +    >>> try:
    +    ...     rmtree(path)
    +    ... except OSError:
    +    ...     pass
    +    """
    +
    +    @property
    +    def k(self):
    +        """
    +        Returns the number of clusters.
    +        """
    +        return self.call("k")
    +
    +    def assignments(self):
    +        """
    +        Returns the cluster assignments of this model.
    +        """
    +        return self.call("getAssignments").map(
    +            lambda x: (PowerIterationClustering.Assignment(*x)))
    +
    +    @classmethod
    +    def load(cls, sc, path):
    +        model = cls._load_java(sc, path)
    +        wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model)
    +        return PowerIterationClusteringModel(wrapper)
    +
    +
    +class PowerIterationClustering(object):
    +    """
    +    .. note:: Experimental
    +
    +    Power Iteration Clustering (PIC), a scalable graph clustering algorithm
    +    developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]].
    +    From the abstract: PIC finds a very low-dimensional embedding of a
    +    dataset using truncated power iteration on a normalized pair-wise
    +    similarity matrix of the data.
    +    """
    +
    +    @classmethod
    +    def train(cls, rdd, k, maxIterations=100, initMode="random"):
    +        """
    +        :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the
    +            affinity matrix, which is the matrix A in the PIC paper.
    +            The similarity s,,ij,, must be nonnegative.
    +            This is a symmetric matrix and hence s,,ij,, = s,,ji,,.
    +            For any (i, j) with nonzero similarity, there should be
    +            either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input.
    +            Tuples with i = j are ignored, because we assume
    +            s,,ij,, = 0.0.
    +        :param k: Number of clusters.
    +        :param maxIterations: Maximum number of iterations of the
    +            PIC algorithm.
    +        :param initMode: Initialization mode.
    +        """
    +        model = callMLlibFunc("trainPowerIterationClusteringModel",
    +                              rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode)
    +        return PowerIterationClusteringModel(model)
    +
    +    class Assignment(namedtuple("Assignment", ["id", "cluster"])):
    +        """
    +        Represents an (id, cluster) tuple.
    +        """
    +
    +
    +class StreamingKMeansModel(KMeansModel):
    +    """
    +    .. note:: Experimental
    +
    +    Clustering model which can perform an online update of the centroids.
    +
    +    The update formula for each centroid is given by
    +
    +    * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t * a + m_t)
    +    * n_t+1 = n_t * a + m_t
    +
    +    where
    +
    +    * c_t: Centroid at the t-th iteration.
    +    * n_t: Number of samples (or weights) associated with the centroid
    +           at the t-th iteration.
    +    * x_t: Centroid of the new data closest to c_t.
    +    * m_t: Number of samples (or weights) of the new data closest to c_t.
    +    * c_t+1: New centroid.
    +    * n_t+1: New number of weights.
    +    * a: Decay Factor, which gives the forgetfulness.
    +
    +    Note that if a is set to 1, the result is the weighted mean of the
    +    previous and new data. If it is set to zero, the old centroids are
    +    completely forgotten.
    +
    +    :param clusterCenters: Initial cluster centers.
    +    :param clusterWeights: List of weights assigned to each cluster.
    +
    +    >>> initCenters = [[0.0, 0.0], [1.0, 1.0]]
    +    >>> initWeights = [1.0, 1.0]
    +    >>> stkm = StreamingKMeansModel(initCenters, initWeights)
    +    >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
    +    ...                        [0.9, 0.9], [1.1, 1.1]])
    +    >>> stkm = stkm.update(data, 1.0, u"batches")
    +    >>> stkm.centers
    +    array([[ 0.,  0.],
    +           [ 1.,  1.]])
    +    >>> stkm.predict([-0.1, -0.1])
    +    0
    +    >>> stkm.predict([0.9, 0.9])
    +    1
    +    >>> stkm.clusterWeights
    +    [3.0, 3.0]
    +    >>> decayFactor = 0.0
    +    >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
    +    >>> stkm = stkm.update(data, 0.0, u"batches")
    +    >>> stkm.centers
    +    array([[ 0.2,  0.2],
    +           [ 1.5,  1.5]])
    +    >>> stkm.clusterWeights
    +    [1.0, 1.0]
    +    >>> stkm.predict([0.2, 0.2])
    +    0
    +    >>> stkm.predict([1.5, 1.5])
    +    1
    +    """
    +    def __init__(self, clusterCenters, clusterWeights):
    +        super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
    +        self._clusterWeights = list(clusterWeights)
    +
    +    @property
    +    def clusterWeights(self):
    +        """Return the cluster weights."""
    +        return self._clusterWeights
    +
    +    @ignore_unicode_prefix
    +    def update(self, data, decayFactor, timeUnit):
    +        """Update the centroids, according to data
    +
    +        :param data: Should be a RDD that represents the new data.
    +        :param decayFactor: forgetfulness of the previous centroids.
    +        :param timeUnit: Can be "batches" or "points". If points, then the
    +                         decay factor is raised to the power of number of new
    +                         points and if batches, it is used as it is.
    +        """
    +        if not isinstance(data, RDD):
    +            raise TypeError("Data should be of an RDD, got %s." % type(data))
    +        data = data.map(_convert_to_vector)
    +        decayFactor = float(decayFactor)
    +        if timeUnit not in ["batches", "points"]:
    +            raise ValueError(
    +                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
    +        vectorCenters = [_convert_to_vector(center) for center in self.centers]
    +        updatedModel = callMLlibFunc(
    +            "updateStreamingKMeansModel", vectorCenters, self._clusterWeights,
    +            data, decayFactor, timeUnit)
    +        self.centers = array(updatedModel[0])
    +        self._clusterWeights = list(updatedModel[1])
    +        return self
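As a quick sanity check of the update formula in the docstring (not part of the patch), the second doctest step for centroid 0 can be reproduced with plain NumPy; note that the previous weight is discounted by `a` in the denominator, which is what the doctest values imply:

```python
import numpy as np

# Centroid 0 after the first update: c_t = [0, 0] with weight n_t = 3.
c_t, n_t, a = np.array([0.0, 0.0]), 3.0, 0.0
# Only [0.2, 0.2] from the second batch is closest to this centroid.
x_t, m_t = np.array([0.2, 0.2]), 1.0

c_next = (c_t * n_t * a + x_t * m_t) / (n_t * a + m_t)
n_next = n_t * a + m_t
print(c_next, n_next)   # [ 0.2  0.2] 1.0 -- matches stkm.centers[0] and clusterWeights[0]
```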
    +
    +
    +class StreamingKMeans(object):
    +    """
    +    .. note:: Experimental
    +
    +    Provides methods to set k, decayFactor, timeUnit to configure the
    +    KMeans algorithm for fitting and predicting on incoming dstreams.
    +    More details on how the centroids are updated are provided under the
    +    docs of StreamingKMeansModel.
    +
    +    :param k: int, number of clusters.
    +    :param decayFactor: float, forgetfulness of the previous centroids.
    +    :param timeUnit: can be "batches" or "points". If "points", the decay
    +                     factor is raised to the power of the number of new points.
    +    """
    +    def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
    +        self._k = k
    +        self._decayFactor = decayFactor
    +        if timeUnit not in ["batches", "points"]:
    +            raise ValueError(
    +                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
    +        self._timeUnit = timeUnit
    +        self._model = None
    +
    +    def latestModel(self):
    +        """Return the latest model"""
    +        return self._model
    +
    +    def _validate(self, dstream):
    +        if self._model is None:
    +            raise ValueError(
    +                "Initial centers should be set either by setInitialCenters "
    +                "or setRandomCenters.")
    +        if not isinstance(dstream, DStream):
    +            raise TypeError(
    +                "Expected dstream to be of type DStream, "
    +                "got type %s" % type(dstream))
    +
    +    def setK(self, k):
    +        """Set number of clusters."""
    +        self._k = k
    +        return self
    +
    +    def setDecayFactor(self, decayFactor):
    +        """Set decay factor."""
    +        self._decayFactor = decayFactor
    +        return self
    +
    +    def setHalfLife(self, halfLife, timeUnit):
    +        """
    +        Set the number of batches after which the centroids of a
    +        particular batch have half their weight.
    +        """
    +        self._timeUnit = timeUnit
    +        self._decayFactor = exp(log(0.5) / halfLife)
    +        return self
    +
    +    def setInitialCenters(self, centers, weights):
    +        """
    +        Set initial centers. Should be set before calling trainOn.
    +        """
    +        self._model = StreamingKMeansModel(centers, weights)
    +        return self
    +
    +    def setRandomCenters(self, dim, weight, seed):
    +        """
    +        Set the initial centers to be random samples from
    +        a Gaussian population with constant weights.
    +        """
    +        rng = random.RandomState(seed)
    +        clusterCenters = rng.randn(self._k, dim)
    +        clusterWeights = tile(weight, self._k)
    +        self._model = StreamingKMeansModel(clusterCenters, clusterWeights)
    +        return self
    +
    +    def trainOn(self, dstream):
    +        """Train the model on the incoming dstream."""
    +        self._validate(dstream)
    +
    +        def update(rdd):
    +            self._model.update(rdd, self._decayFactor, self._timeUnit)
    +
    +        dstream.foreachRDD(update)
    +
    +    def predictOn(self, dstream):
    +        """
    +        Make predictions on a dstream.
    +        Returns a transformed dstream object
    +        """
    +        self._validate(dstream)
    +        return dstream.map(lambda x: self._model.predict(x))
    +
    +    def predictOnValues(self, dstream):
    +        """
    +        Make predictions on a keyed dstream.
    +        Returns a transformed dstream object.
    +        """
    +        self._validate(dstream)
    +        return dstream.mapValues(lambda x: self._model.predict(x))
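Putting the setters together, a minimal end-to-end sketch of the streaming K-means API (illustrative only; again assumes an existing StreamingContext `ssc`):

```python
from pyspark.mllib.clustering import StreamingKMeans

# One batch of 2-dimensional points, wrapped in a queueStream for the demo.
points = sc.parallelize([[0.0, 0.1], [0.1, 0.0], [1.1, 0.9], [0.9, 1.2]])
trainingStream = ssc.queueStream([points])

skm = (StreamingKMeans(k=2)
       .setHalfLife(5, "batches")          # decayFactor = exp(log(0.5) / 5), roughly 0.87
       .setRandomCenters(dim=2, weight=1.0, seed=42))
skm.trainOn(trainingStream)
skm.predictOn(trainingStream).pprint()     # DStream of cluster indices
```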
    +
    +
     def _test():
         import doctest
    -    globs = globals().copy()
    +    import pyspark.mllib.clustering
    +    globs = pyspark.mllib.clustering.__dict__.copy()
         globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         globs['sc'].stop()
    diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
    index ba6058978880a..855e85f57155e 100644
    --- a/python/pyspark/mllib/common.py
    +++ b/python/pyspark/mllib/common.py
    @@ -27,7 +27,7 @@
     
     from pyspark import RDD, SparkContext
     from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
    -
    +from pyspark.sql import DataFrame, SQLContext
     
     # Hack for support float('inf') in Py4j
     _old_smart_decode = py4j.protocol.smart_decode
    @@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"):
                 jrdd = sc._jvm.SerDe.javaToPython(r)
                 return RDD(jrdd, sc)
     
    +        if clsName == 'DataFrame':
    +            return DataFrame(r, SQLContext(sc))
    +
             if clsName in _picklable_classes:
                 r = sc._jvm.SerDe.dumps(r)
             elif isinstance(r, (JavaArray, JavaList)):
    diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
    index 4c777f2180dc9..f21403707e12a 100644
    --- a/python/pyspark/mllib/evaluation.py
    +++ b/python/pyspark/mllib/evaluation.py
    @@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper):
         """
         Evaluator for binary classification.
     
    +    :param scoreAndLabels: an RDD of (score, label) pairs
    +
         >>> scoreAndLabels = sc.parallelize([
         ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
         >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
    @@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper):
         """
     
         def __init__(self, scoreAndLabels):
    -        """
    -        :param scoreAndLabels: an RDD of (score, label) pairs
    -        """
             sc = scoreAndLabels.ctx
             sql_ctx = SQLContext(sc)
             df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
    @@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper):
         """
         Evaluator for regression.
     
    +    :param predictionAndObservations: an RDD of (prediction,
    +                                      observation) pairs.
    +
         >>> predictionAndObservations = sc.parallelize([
         ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
         >>> metrics = RegressionMetrics(predictionAndObservations)
    @@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper):
         """
     
         def __init__(self, predictionAndObservations):
    -        """
    -        :param predictionAndObservations: an RDD of (prediction, observation) pairs.
    -        """
             sc = predictionAndObservations.ctx
             sql_ctx = SQLContext(sc)
             df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
    @@ -148,9 +147,15 @@ class MulticlassMetrics(JavaModelWrapper):
         """
         Evaluator for multiclass classification.
     
    +    :param predictionAndLabels: an RDD of (prediction, label) pairs.
    +
         >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
         ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
         >>> metrics = MulticlassMetrics(predictionAndLabels)
    +    >>> metrics.confusionMatrix().toArray()
    +    array([[ 2.,  1.,  1.],
    +           [ 1.,  3.,  0.],
    +           [ 0.,  0.,  1.]])
         >>> metrics.falsePositiveRate(0.0)
         0.2...
         >>> metrics.precision(1.0)
    @@ -176,9 +181,6 @@ class MulticlassMetrics(JavaModelWrapper):
         """
     
         def __init__(self, predictionAndLabels):
    -        """
    -        :param predictionAndLabels an RDD of (prediction, label) pairs.
    -        """
             sc = predictionAndLabels.ctx
             sql_ctx = SQLContext(sc)
             df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
    @@ -188,6 +190,13 @@ def __init__(self, predictionAndLabels):
             java_model = java_class(df._jdf)
             super(MulticlassMetrics, self).__init__(java_model)
     
    +    def confusionMatrix(self):
    +        """
    +        Returns the confusion matrix: predicted classes are in columns,
    +        ordered by ascending class label, as in "labels".
    +        """
    +        return self.call("confusionMatrix")
    +
         def truePositiveRate(self, label):
             """
             Returns true positive rate for a given label (category).
    @@ -277,6 +286,9 @@ class RankingMetrics(JavaModelWrapper):
         """
         Evaluator for ranking algorithms.
     
    +    :param predictionAndLabels: an RDD of (predicted ranking,
    +                                ground truth set) pairs.
    +
         >>> predictionAndLabels = sc.parallelize([
         ...     ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
         ...     ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
    @@ -298,9 +310,6 @@ class RankingMetrics(JavaModelWrapper):
         """
     
         def __init__(self, predictionAndLabels):
    -        """
    -        :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
    -        """
             sc = predictionAndLabels.ctx
             sql_ctx = SQLContext(sc)
             df = sql_ctx.createDataFrame(predictionAndLabels,
    @@ -334,16 +343,136 @@ def ndcgAt(self, k):
             """
             Compute the average NDCG value of all the queries, truncated at ranking position k.
             The discounted cumulative gain at position k is computed as:
    -            sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
    +        sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
             and the NDCG is obtained by dividing the DCG value on the ground truth set.
             In the current implementation, the relevance value is binary.
    -
    -        If a query has an empty ground truth set, zero will be used as ndcg together with
    +        If a query has an empty ground truth set, zero will be used as NDCG together with
             a log warning.
             """
             return self.call("ndcgAt", int(k))
     
     
    +class MultilabelMetrics(JavaModelWrapper):
    +    """
    +    Evaluator for multilabel classification.
    +
    +    :param predictionAndLabels: an RDD of (predictions, labels) pairs,
    +                                where both are non-null arrays, each
    +                                with unique elements.
    +
    +    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
    +    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
    +    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
    +    >>> metrics = MultilabelMetrics(predictionAndLabels)
    +    >>> metrics.precision(0.0)
    +    1.0
    +    >>> metrics.recall(1.0)
    +    0.66...
    +    >>> metrics.f1Measure(2.0)
    +    0.5
    +    >>> metrics.precision()
    +    0.66...
    +    >>> metrics.recall()
    +    0.64...
    +    >>> metrics.f1Measure()
    +    0.63...
    +    >>> metrics.microPrecision
    +    0.72...
    +    >>> metrics.microRecall
    +    0.66...
    +    >>> metrics.microF1Measure
    +    0.69...
    +    >>> metrics.hammingLoss
    +    0.33...
    +    >>> metrics.subsetAccuracy
    +    0.28...
    +    >>> metrics.accuracy
    +    0.54...
    +    """
    +
    +    def __init__(self, predictionAndLabels):
    +        sc = predictionAndLabels.ctx
    +        sql_ctx = SQLContext(sc)
    +        df = sql_ctx.createDataFrame(predictionAndLabels,
    +                                     schema=sql_ctx._inferSchema(predictionAndLabels))
    +        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
    +        java_model = java_class(df._jdf)
    +        super(MultilabelMetrics, self).__init__(java_model)
    +
    +    def precision(self, label=None):
    +        """
    +        Returns precision or precision for a given label (category) if specified.
    +        """
    +        if label is None:
    +            return self.call("precision")
    +        else:
    +            return self.call("precision", float(label))
    +
    +    def recall(self, label=None):
    +        """
    +        Returns recall or recall for a given label (category) if specified.
    +        """
    +        if label is None:
    +            return self.call("recall")
    +        else:
    +            return self.call("recall", float(label))
    +
    +    def f1Measure(self, label=None):
    +        """
    +        Returns f1Measure or f1Measure for a given label (category) if specified.
    +        """
    +        if label is None:
    +            return self.call("f1Measure")
    +        else:
    +            return self.call("f1Measure", float(label))
    +
    +    @property
    +    def microPrecision(self):
    +        """
    +        Returns micro-averaged label-based precision.
    +        (equal to micro-averaged document-based precision)
    +        """
    +        return self.call("microPrecision")
    +
    +    @property
    +    def microRecall(self):
    +        """
    +        Returns micro-averaged label-based recall.
    +        (equal to micro-averaged document-based recall)
    +        """
    +        return self.call("microRecall")
    +
    +    @property
    +    def microF1Measure(self):
    +        """
    +        Returns micro-averaged label-based f1-measure.
    +        (equal to micro-averaged document-based f1-measure)
    +        """
    +        return self.call("microF1Measure")
    +
    +    @property
    +    def hammingLoss(self):
    +        """
    +        Returns Hamming-loss.
    +        """
    +        return self.call("hammingLoss")
    +
    +    @property
    +    def subsetAccuracy(self):
    +        """
    +        Returns subset accuracy.
    +        (for equal sets of labels)
    +        """
    +        return self.call("subsetAccuracy")
    +
    +    @property
    +    def accuracy(self):
    +        """
    +        Returns accuracy.
    +        """
    +        return self.call("accuracy")
    +
    +
     def _test():
         import doctest
         from pyspark import SparkContext
    diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
    index aac305db6c19a..f921e3ad1a314 100644
    --- a/python/pyspark/mllib/feature.py
    +++ b/python/pyspark/mllib/feature.py
    @@ -33,12 +33,14 @@
     from pyspark import SparkContext
     from pyspark.rdd import RDD, ignore_unicode_prefix
     from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
    -from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
    +from pyspark.mllib.linalg import (
    +    Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
     from pyspark.mllib.regression import LabeledPoint
    +from pyspark.mllib.util import JavaLoader, JavaSaveable
     
     __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
                'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
    -           'ChiSqSelector', 'ChiSqSelectorModel']
    +           'ChiSqSelector', 'ChiSqSelectorModel', 'ElementwiseProduct']
     
     
     class VectorTransformer(object):
    @@ -68,6 +70,8 @@ class Normalizer(VectorTransformer):
         For `p` = float('inf'), max(abs(vector)) will be used as norm for
         normalization.
     
    +    :param p: Normalization in L^p^ space, p = 2 by default.
    +
         >>> v = Vectors.dense(range(3))
         >>> nor = Normalizer(1)
         >>> nor.transform(v)
    @@ -82,9 +86,6 @@ class Normalizer(VectorTransformer):
         DenseVector([0.0, 0.5, 1.0])
         """
         def __init__(self, p=2.0):
    -        """
    -        :param p: Normalization in L^p^ space, p = 2 by default.
    -        """
             assert p >= 1.0, "p should be greater than 1.0"
             self.p = float(p)
     
    @@ -94,7 +95,7 @@ def transform(self, vector):
     
             :param vector: vector or RDD of vector to be normalized.
             :return: normalized vector. If the norm of the input is zero, it
    -                will return the input vector.
    +                 will return the input vector.
             """
             sc = SparkContext._active_spark_context
             assert sc is not None, "SparkContext should be initialized first"
    @@ -111,6 +112,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
         """
     
         def transform(self, vector):
    +        """
    +        Applies transformation on a vector or an RDD[Vector].
    +
    +        Note: In Python, transform cannot currently be used within
    +              an RDD transformation or action.
    +              Call transform directly on the RDD instead.
    +
    +        :param vector: Vector or RDD of Vector to be transformed.
    +        """
             if isinstance(vector, RDD):
                 vector = vector.map(_convert_to_vector)
             else:
    @@ -164,6 +174,13 @@ class StandardScaler(object):
         variance using column summary statistics on the samples in the
         training set.
     
    +    :param withMean: False by default. Centers the data with mean
    +                     before scaling. It will build a dense output, so this
    +                     does not work on sparse input and will raise an
    +                     exception.
    +    :param withStd: True by default. Scales the data to unit
    +                    standard deviation.
    +
         >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])]
         >>> dataset = sc.parallelize(vs)
         >>> standardizer = StandardScaler(True, True)
    @@ -174,14 +191,6 @@ class StandardScaler(object):
         DenseVector([0.7071, -0.7071, 0.7071])
         """
         def __init__(self, withMean=False, withStd=True):
    -        """
    -        :param withMean: False by default. Centers the data with mean
    -                 before scaling. It will build a dense output, so this
    -                 does not work on sparse input and will raise an
    -                 exception.
    -        :param withStd: True by default. Scales the data to unit
    -                 standard deviation.
    -        """
             if not (withMean or withStd):
                 warnings.warn("Both withMean and withStd are false. The model does nothing.")
             self.withMean = withMean
    @@ -192,8 +201,8 @@ def fit(self, dataset):
             Computes the mean and variance and stores as a model to be used
             for later scaling.
     
    -        :param data: The data used to compute the mean and variance
    -                 to build the transformation model.
    +        :param dataset: The data used to compute the mean and variance
    +                     to build the transformation model.
             :return: a StandardScalerModel
             """
             dataset = dataset.map(_convert_to_vector)
    @@ -223,6 +232,8 @@ class ChiSqSelector(object):
     
         Creates a ChiSquared feature selector.
     
    +    :param numTopFeatures: number of features that selector will select.
    +
         >>> data = [
         ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
         ...     LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
    @@ -236,9 +247,6 @@ class ChiSqSelector(object):
         DenseVector([5.0])
         """
         def __init__(self, numTopFeatures):
    -        """
    -        :param numTopFeatures: number of features that selector will select.
    -        """
             self.numTopFeatures = int(numTopFeatures)
     
         def fit(self, data):
    @@ -246,14 +254,49 @@ def fit(self, data):
             Returns a ChiSquared feature selector.
     
             :param data: an `RDD[LabeledPoint]` containing the labeled dataset
    -                 with categorical features. Real-valued features will be
    -                 treated as categorical for each distinct value.
    -                 Apply feature discretizer before using this function.
    +                     with categorical features. Real-valued features will be
    +                     treated as categorical for each distinct value.
    +                     Apply feature discretizer before using this function.
             """
             jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
             return ChiSqSelectorModel(jmodel)
     
     
    +class PCAModel(JavaVectorTransformer):
    +    """
    +    Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA.
    +    """
    +
    +
    +class PCA(object):
    +    """
    +    A feature transformer that projects vectors to a low-dimensional space using PCA.
    +
    +    >>> data = [Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),
    +    ...     Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),
    +    ...     Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])]
    +    >>> model = PCA(2).fit(sc.parallelize(data))
    +    >>> pcArray = model.transform(Vectors.sparse(5, [(1, 1.0), (3, 7.0)])).toArray()
    +    >>> pcArray[0]
    +    1.648...
    +    >>> pcArray[1]
    +    -4.013...
    +    """
    +    def __init__(self, k):
    +        """
    +        :param k: number of principal components.
    +        """
    +        self.k = int(k)
    +
    +    def fit(self, data):
    +        """
    +        Computes a [[PCAModel]] that contains the principal components of the input vectors.
    +
    +        :param data: source vectors
    +        """
    +        jmodel = callMLlibFunc("fitPCA", self.k, data)
    +        return PCAModel(jmodel)
    +
    +
     class HashingTF(object):
         """
         .. note:: Experimental
    @@ -263,15 +306,14 @@ class HashingTF(object):
     
         Note: the terms must be hashable (can not be dict/set/list...).
     
    +    :param numFeatures: number of features (default: 2^20)
    +
         >>> htf = HashingTF(100)
         >>> doc = "a a b b c d".split(" ")
         >>> htf.transform(doc)
         SparseVector(100, {...})
         """
         def __init__(self, numFeatures=1 << 20):
    -        """
    -        :param numFeatures: number of features (default: 2^20)
    -        """
             self.numFeatures = numFeatures
     
         def indexOf(self, term):
    @@ -311,13 +353,9 @@ def transform(self, x):
                   Call transform directly on the RDD instead.
     
             :param x: an RDD of term frequency vectors or a term frequency
    -                 vector
    +                  vector
             :return: an RDD of TF-IDF vectors or a TF-IDF vector
             """
    -        if isinstance(x, RDD):
    -            return JavaVectorTransformer.transform(self, x)
    -
    -        x = _convert_to_vector(x)
             return JavaVectorTransformer.transform(self, x)
     
         def idf(self):
    @@ -342,6 +380,9 @@ class IDF(object):
         `minDocFreq`). For terms that are not in at least `minDocFreq`
         documents, the IDF is found as 0, resulting in TF-IDFs of 0.
     
    +    :param minDocFreq: minimum number of documents in which a term
    +                       must appear to receive a nonzero IDF
    +
         >>> n = 4
         >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)),
         ...          Vectors.dense([0.0, 1.0, 2.0, 3.0]),
    @@ -362,10 +403,6 @@ class IDF(object):
         SparseVector(4, {1: 0.0, 3: 0.5754})
         """
         def __init__(self, minDocFreq=0):
    -        """
    -        :param minDocFreq: minimum of documents in which a term
    -                           should appear for filtering
    -        """
             self.minDocFreq = minDocFreq
     
         def fit(self, dataset):
    @@ -380,7 +417,7 @@ def fit(self, dataset):
             return IDFModel(jmodel)
     
     
    -class Word2VecModel(JavaVectorTransformer):
    +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
         """
         class for Word2Vec model
         """
    @@ -419,6 +456,12 @@ def getVectors(self):
             """
             return self.call("getVectors")
     
    +    @classmethod
    +    def load(cls, sc, path):
    +        jmodel = sc._jvm.org.apache.spark.mllib.feature \
    +            .Word2VecModel.load(sc._jsc.sc(), path)
    +        return Word2VecModel(jmodel)
    +
     
     @ignore_unicode_prefix
     class Word2Vec(object):
    @@ -452,6 +495,18 @@ class Word2Vec(object):
         >>> syms = model.findSynonyms(vec, 2)
         >>> [s[0] for s in syms]
         [u'b', u'c']
    +
    +    >>> import os, tempfile
    +    >>> path = tempfile.mkdtemp()
    +    >>> model.save(sc, path)
    +    >>> sameModel = Word2VecModel.load(sc, path)
    +    >>> model.transform("a") == sameModel.transform("a")
    +    True
    +    >>> from shutil import rmtree
    +    >>> try:
    +    ...     rmtree(path)
    +    ... except OSError:
    +    ...     pass
         """
         def __init__(self):
             """
    @@ -518,13 +573,45 @@ def fit(self, data):
             """
             if not isinstance(data, RDD):
                 raise TypeError("data should be an RDD of list of string")
    -        jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
    +        jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize),
                                    float(self.learningRate), int(self.numPartitions),
                                    int(self.numIterations), int(self.seed),
                                    int(self.minCount))
             return Word2VecModel(jmodel)
     
     
    +class ElementwiseProduct(VectorTransformer):
    +    """
    +    .. note:: Experimental
    +
    +    Scales each element of the vector by the corresponding element of a
    +    supplied weight vector, i.e. the elementwise (Hadamard) product.
    +
    +    >>> weight = Vectors.dense([1.0, 2.0, 3.0])
    +    >>> eprod = ElementwiseProduct(weight)
    +    >>> a = Vectors.dense([2.0, 1.0, 3.0])
    +    >>> eprod.transform(a)
    +    DenseVector([2.0, 2.0, 9.0])
    +    >>> b = Vectors.dense([9.0, 3.0, 4.0])
    +    >>> rdd = sc.parallelize([a, b])
    +    >>> eprod.transform(rdd).collect()
    +    [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])]
    +    """
    +    def __init__(self, scalingVector):
    +        self.scalingVector = _convert_to_vector(scalingVector)
    +
    +    def transform(self, vector):
    +        """
    +        Computes the Hadamard product of the vector and the scaling vector.
    +        """
    +        if isinstance(vector, RDD):
    +            vector = vector.map(_convert_to_vector)
    +
    +        else:
    +            vector = _convert_to_vector(vector)
    +        return callMLlibFunc("elementwiseProductVector", self.scalingVector, vector)
    +
    +
     def _test():
         import doctest
         from pyspark import SparkContext
    diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
    index d8df02bdbaba9..bdc4a132b1b18 100644
    --- a/python/pyspark/mllib/fpm.py
    +++ b/python/pyspark/mllib/fpm.py
    @@ -61,12 +61,12 @@ class FPGrowth(object):
         def train(cls, data, minSupport=0.3, numPartitions=-1):
             """
             Computes an FP-Growth model that contains frequent itemsets.
    -        :param data:            The input data set, each element
    -                                contains a transaction.
    -        :param minSupport:      The minimal support level
    -                                (default: `0.3`).
    -        :param numPartitions:   The number of partitions used by parallel
    -                                FP-growth (default: same as input data).
    +
    +        :param data: The input data set, each element contains a
    +            transaction.
    +        :param minSupport: The minimal support level (default: `0.3`).
    +        :param numPartitions: The number of partitions used by
    +            parallel FP-growth (default: same as input data).
             """
             model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
             return FPGrowthModel(model)
    diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
    index 23d1a79ffe511..040886f71775b 100644
    --- a/python/pyspark/mllib/linalg.py
    +++ b/python/pyspark/mllib/linalg.py
    @@ -31,12 +31,13 @@
         xrange = range
         import copyreg as copy_reg
     else:
    +    from itertools import izip as zip
         import copy_reg
     
     import numpy as np
     
     from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
    -    IntegerType, ByteType
    +    IntegerType, ByteType, BooleanType
     
     
     __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
    @@ -116,6 +117,10 @@ def _format_float(f, digits=4):
         return s
     
     
    +def _format_float_list(l):
    +    return [_format_float(x) for x in l]
    +
    +
     class VectorUDT(UserDefinedType):
         """
         SQL user-defined type (UDT) for Vector.
    @@ -163,6 +168,59 @@ def simpleString(self):
             return "vector"
     
     
    +class MatrixUDT(UserDefinedType):
    +    """
    +    SQL user-defined type (UDT) for Matrix.
    +    """
    +
    +    @classmethod
    +    def sqlType(cls):
    +        return StructType([
    +            StructField("type", ByteType(), False),
    +            StructField("numRows", IntegerType(), False),
    +            StructField("numCols", IntegerType(), False),
    +            StructField("colPtrs", ArrayType(IntegerType(), False), True),
    +            StructField("rowIndices", ArrayType(IntegerType(), False), True),
    +            StructField("values", ArrayType(DoubleType(), False), True),
    +            StructField("isTransposed", BooleanType(), False)])
    +
    +    @classmethod
    +    def module(cls):
    +        return "pyspark.mllib.linalg"
    +
    +    @classmethod
    +    def scalaUDT(cls):
    +        return "org.apache.spark.mllib.linalg.MatrixUDT"
    +
    +    def serialize(self, obj):
    +        if isinstance(obj, SparseMatrix):
    +            colPtrs = [int(i) for i in obj.colPtrs]
    +            rowIndices = [int(i) for i in obj.rowIndices]
    +            values = [float(v) for v in obj.values]
    +            return (0, obj.numRows, obj.numCols, colPtrs,
    +                    rowIndices, values, bool(obj.isTransposed))
    +        elif isinstance(obj, DenseMatrix):
    +            values = [float(v) for v in obj.values]
    +            return (1, obj.numRows, obj.numCols, None, None, values,
    +                    bool(obj.isTransposed))
    +        else:
    +            raise TypeError("cannot serialize type %r" % (type(obj)))
    +
    +    def deserialize(self, datum):
    +        assert len(datum) == 7, \
    +            "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
    +        tpe = datum[0]
    +        if tpe == 0:
    +            return SparseMatrix(*datum[1:])
    +        elif tpe == 1:
    +            return DenseMatrix(datum[1], datum[2], datum[5], datum[6])
    +        else:
    +            raise ValueError("do not recognize type %r" % tpe)
    +
    +    def simpleString(self):
    +        return "matrix"
    +
    +
     class Vector(object):
     
         __UDT__ = VectorUDT()
    @@ -387,8 +445,10 @@ def __init__(self, size, *args):
             values (sorted by index).
     
             :param size: Size of the vector.
    -        :param args: Non-zero entries, as a dictionary, list of tupes,
    -               or two sorted lists containing indices and values.
    +        :param args: Active entries, as a dictionary {index: value, ...},
    +          a list of tuples [(index, value), ...], or a list of strictly
    +          increasing indices and a list of corresponding values
    +          [index, ...], [value, ...]. Inactive entries are treated as zeros.
     
             >>> SparseVector(4, {1: 1.0, 3: 5.5})
             SparseVector(4, {1: 1.0, 3: 5.5})
    @@ -398,6 +458,7 @@ def __init__(self, size, *args):
             SparseVector(4, {1: 1.0, 3: 5.5})
             """
             self.size = int(size)
    +        """ Size of the vector. """
             assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
             if len(args) == 1:
                 pairs = args[0]
    @@ -405,7 +466,9 @@ def __init__(self, size, *args):
                     pairs = pairs.items()
                 pairs = sorted(pairs)
                 self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
    +            """ A list of indices corresponding to active entries. """
                 self.values = np.array([p[1] for p in pairs], dtype=np.float64)
    +            """ A list of values corresponding to active entries. """
             else:
                 if isinstance(args[0], bytes):
                     assert isinstance(args[1], bytes), "values should be string too"
    @@ -524,34 +587,27 @@ def dot(self, other):
                 ...
             AssertionError: dimension mismatch
             """
    -        if type(other) == np.ndarray:
    -            if other.ndim == 2:
    -                results = [self.dot(other[:, i]) for i in xrange(other.shape[1])]
    -                return np.array(results)
    -            elif other.ndim > 2:
    +
    +        if isinstance(other, np.ndarray):
    +            if other.ndim not in [2, 1]:
                     raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim)
    +            assert len(self) == other.shape[0], "dimension mismatch"
    +            return np.dot(self.values, other[self.indices])
     
             assert len(self) == _vector_size(other), "dimension mismatch"
     
    -        if type(other) in (np.ndarray, array.array, DenseVector):
    -            result = 0.0
    -            for i in xrange(len(self.indices)):
    -                result += self.values[i] * other[self.indices[i]]
    -            return result
    +        if isinstance(other, DenseVector):
    +            return np.dot(other.array[self.indices], self.values)
     
    -        elif type(other) is SparseVector:
    -            result = 0.0
    -            i, j = 0, 0
    -            while i < len(self.indices) and j < len(other.indices):
    -                if self.indices[i] == other.indices[j]:
    -                    result += self.values[i] * other.values[j]
    -                    i += 1
    -                    j += 1
    -                elif self.indices[i] < other.indices[j]:
    -                    i += 1
    -                else:
    -                    j += 1
    -            return result
    +        elif isinstance(other, SparseVector):
    +            # Find out common indices.
    +            self_cmind = np.in1d(self.indices, other.indices, assume_unique=True)
    +            self_values = self.values[self_cmind]
    +            if self_values.size == 0:
    +                return 0.0
    +            else:
    +                other_cmind = np.in1d(other.indices, self.indices, assume_unique=True)
    +                return np.dot(self_values, other.values[other_cmind])
     
             else:
                 return self.dot(_convert_to_vector(other))
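The sparse-sparse branch above replaces the explicit index-merge loop with `numpy.in1d`. A NumPy-only sketch of the same idea, on plain arrays (no pyspark needed):

```
import numpy as np

# Two sparse vectors of the same size, stored as (sorted indices, values).
a_ind, a_val = np.array([1, 3, 6]), np.array([1.0, 5.5, 2.0])
b_ind, b_val = np.array([0, 3, 6]), np.array([4.0, 2.0, 3.0])

# Boolean masks marking the indices that appear in both vectors.
a_common = np.in1d(a_ind, b_ind, assume_unique=True)
b_common = np.in1d(b_ind, a_ind, assume_unique=True)

# Only the values at the shared indices contribute to the dot product.
print(np.dot(a_val[a_common], b_val[b_common]))  # 5.5*2.0 + 2.0*3.0 = 17.0
```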
    @@ -582,22 +638,23 @@ def squared_distance(self, other):
             AssertionError: dimension mismatch
             """
             assert len(self) == _vector_size(other), "dimension mismatch"
    -        if type(other) in (list, array.array, DenseVector, np.array, np.ndarray):
    -            if type(other) is np.array and other.ndim != 1:
    +
    +        if isinstance(other, np.ndarray) or isinstance(other, DenseVector):
    +            if isinstance(other, np.ndarray) and other.ndim != 1:
                     raise Exception("Cannot call squared_distance with %d-dimensional array" %
                                     other.ndim)
    -            result = 0.0
    -            j = 0   # index into our own array
    -            for i in xrange(len(other)):
    -                if j < len(self.indices) and self.indices[j] == i:
    -                    diff = self.values[j] - other[i]
    -                    result += diff * diff
    -                    j += 1
    -                else:
    -                    result += other[i] * other[i]
    +            if isinstance(other, DenseVector):
    +                other = other.array
    +            sparse_ind = np.zeros(other.size, dtype=bool)
    +            sparse_ind[self.indices] = True
    +            dist = other[sparse_ind] - self.values
    +            result = np.dot(dist, dist)
    +
    +            other_ind = other[~sparse_ind]
    +            result += np.dot(other_ind, other_ind)
                 return result
     
    -        elif type(other) is SparseVector:
    +        elif isinstance(other, SparseVector):
                 result = 0.0
                 i, j = 0, 0
                 while i < len(self.indices) and j < len(other.indices):
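For a dense operand, the rewritten branch above computes the squared distance in two vectorized pieces: the positions where the sparse vector has an active entry, and the remaining positions where it is implicitly zero. A NumPy-only sketch:

```
import numpy as np

size = 4
sp_ind = np.array([1, 3])            # active indices of the sparse vector
sp_val = np.array([1.0, 2.0])        # corresponding values
dense = np.array([1.0, 2.0, 3.0, 4.0])

active = np.zeros(size, dtype=bool)
active[sp_ind] = True

# Squared differences at the active positions ...
diff = dense[active] - sp_val
result = np.dot(diff, diff)

# ... plus the squared dense values wherever the sparse vector is zero.
rest = dense[~active]
result += np.dot(rest, rest)

print(result)  # (2-1)^2 + (4-2)^2 + 1^2 + 3^2 = 15.0
```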
    @@ -781,10 +838,12 @@ def zeros(size):
     
     
     class Matrix(object):
         """
         Represents a local matrix.
         """
     
    +    __UDT__ = MatrixUDT()
    +
         def __init__(self, numRows, numCols, isTransposed=False):
             self.numRows = numRows
             self.numCols = numCols
    @@ -821,6 +880,50 @@ def __reduce__(self):
                 self.numRows, self.numCols, self.values.tostring(),
                 int(self.isTransposed))
     
    +    def __str__(self):
    +        """
    +        Pretty printing of a DenseMatrix
    +
    +        >>> dm = DenseMatrix(2, 2, range(4))
    +        >>> print(dm)
    +        DenseMatrix([[ 0.,  2.],
    +                     [ 1.,  3.]])
    +        >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True)
    +        >>> print(dm)
    +        DenseMatrix([[ 0.,  1.],
    +                     [ 2.,  3.]])
    +        """
    +        # Inspired by __repr__ in scipy matrices.
    +        array_lines = repr(self.toArray()).splitlines()
    +
    +        # Pad the wrapped lines with six spaces, the difference in length
    +        # between "DenseMatrix" and "array", so the columns stay aligned.
    +        x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]])
    +        return array_lines[0].replace("array", "DenseMatrix") + "\n" + x
    +
    +    def __repr__(self):
    +        """
    +        Representation of a DenseMatrix
    +
    +        >>> dm = DenseMatrix(2, 2, range(4))
    +        >>> dm
    +        DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False)
    +        """
    +        # If there are fewer than seventeen values, show them all.
    +        # Otherwise show the first eight and the last eight values.
    +        if len(self.values) < 17:
    +            entries = _format_float_list(self.values)
    +        else:
    +            entries = (
    +                _format_float_list(self.values[:8]) +
    +                ["..."] +
    +                _format_float_list(self.values[-8:])
    +            )
    +
    +        entries = ", ".join(entries)
    +        return "DenseMatrix({0}, {1}, [{2}], {3})".format(
    +            self.numRows, self.numCols, entries, self.isTransposed)
    +
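The `__repr__` above shows at most 16 values: with 17 or more entries it prints the first eight, an ellipsis, and the last eight. A small illustration (assumes the patch is applied):

```
from pyspark.mllib.linalg import DenseMatrix

small = DenseMatrix(2, 2, range(4))
print(repr(small))
# DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False)

big = DenseMatrix(5, 4, range(20))   # 20 values >= 17, so the repr is truncated
print(repr(big))
# DenseMatrix(5, 4, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, ..., 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0], False)
```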
         def toArray(self):
             """
             Return an numpy.ndarray
    @@ -897,6 +1000,84 @@ def __init__(self, numRows, numCols, colPtrs, rowIndices, values,
                 raise ValueError("Expected rowIndices of length %d, got %d."
                                  % (self.rowIndices.size, self.values.size))
     
    +    def __str__(self):
    +        """
    +        Pretty printing of a SparseMatrix
    +
    +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
    +        >>> print(sm1)
    +        2 X 2 CSCMatrix
    +        (0,0) 2.0
    +        (1,0) 3.0
    +        (1,1) 4.0
    +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
    +        >>> print(sm1)
    +        2 X 2 CSRMatrix
    +        (0,0) 2.0
    +        (0,1) 3.0
    +        (1,1) 4.0
    +        """
    +        spstr = "{0} X {1} ".format(self.numRows, self.numCols)
    +        if self.isTransposed:
    +            spstr += "CSRMatrix\n"
    +        else:
    +            spstr += "CSCMatrix\n"
    +
    +        cur_col = 0
    +        smlist = []
    +
    +        # Display first 16 values.
    +        if len(self.values) <= 16:
    +            zipindval = zip(self.rowIndices, self.values)
    +        else:
    +            zipindval = zip(self.rowIndices[:16], self.values[:16])
    +        for i, (rowInd, value) in enumerate(zipindval):
    +            if self.colPtrs[cur_col + 1] <= i:
    +                cur_col += 1
    +            if self.isTransposed:
    +                smlist.append('({0},{1}) {2}'.format(
    +                    cur_col, rowInd, _format_float(value)))
    +            else:
    +                smlist.append('({0},{1}) {2}'.format(
    +                    rowInd, cur_col, _format_float(value)))
    +        spstr += "\n".join(smlist)
    +
    +        if len(self.values) > 16:
    +            spstr += "\n.." * 2
    +        return spstr
    +
    +    def __repr__(self):
    +        """
    +        Representation of a SparseMatrix
    +
    +        >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
    +        >>> sm1
    +        SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False)
    +        """
    +        rowIndices = list(self.rowIndices)
    +        colPtrs = list(self.colPtrs)
    +
    +        if len(self.values) <= 16:
    +            values = _format_float_list(self.values)
    +
    +        else:
    +            values = (
    +                _format_float_list(self.values[:8]) +
    +                ["..."] +
    +                _format_float_list(self.values[-8:])
    +            )
    +            rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:]
    +
    +        if len(self.colPtrs) > 16:
    +            colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:]
    +
    +        values = ", ".join(values)
    +        rowIndices = ", ".join([str(ind) for ind in rowIndices])
    +        colPtrs = ", ".join([str(ptr) for ptr in colPtrs])
    +        return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
    +            self.numRows, self.numCols, colPtrs, rowIndices,
    +            values, self.isTransposed)
    +
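The `__str__` above walks `colPtrs` to recover the column of each stored value: in CSC layout, column `c` owns the entries at positions `colPtrs[c] <= i < colPtrs[c + 1]` of `rowIndices`/`values`. A plain-Python sketch of that decoding, using the matrix from the doctest:

```
# Decode a CSC triple (colPtrs, rowIndices, values) into (row, col, value) entries.
colPtrs = [0, 2, 3]
rowIndices = [0, 1, 1]
values = [2.0, 3.0, 4.0]

entries = []
for col in range(len(colPtrs) - 1):
    for i in range(colPtrs[col], colPtrs[col + 1]):
        entries.append((rowIndices[i], col, values[i]))

print(entries)  # [(0, 0, 2.0), (1, 0, 3.0), (1, 1, 4.0)]
```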
         def __reduce__(self):
             return SparseMatrix, (
                 self.numRows, self.numCols, self.colPtrs.tostring(),
    diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py
    similarity index 100%
    rename from python/pyspark/mllib/rand.py
    rename to python/pyspark/mllib/random.py
    diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
    index 9c4647ddfdcfd..506ca2151cce7 100644
    --- a/python/pyspark/mllib/recommendation.py
    +++ b/python/pyspark/mllib/recommendation.py
    @@ -106,8 +106,9 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
         0.4...
         >>> sameModel.predictAll(testset).collect()
         [Rating(...
    +    >>> from shutil import rmtree
         >>> try:
    -    ...     os.removedirs(path)
    +    ...     rmtree(path)
         ... except OSError:
         ...     pass
         """
    diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
    index 41bde2ce3e60b..8e90adee5f4c2 100644
    --- a/python/pyspark/mllib/regression.py
    +++ b/python/pyspark/mllib/regression.py
    @@ -19,6 +19,7 @@
     from numpy import array
     
     from pyspark import RDD
    +from pyspark.streaming.dstream import DStream
     from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
     from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector
     from pyspark.mllib.util import Saveable, Loader
    @@ -33,12 +34,12 @@
     class LabeledPoint(object):
     
         """
    -    The features and labels of a data point.
    +    Class that represents the features and labels of a data point.
     
         :param label: Label for this data point.
         :param features: Vector of features for this point (NumPy array,
    -             list, pyspark.mllib.linalg.SparseVector, or scipy.sparse
    -             column matrix)
    +            list, pyspark.mllib.linalg.SparseVector, or scipy.sparse
    +            column matrix)
     
         Note: 'label' and 'features' are accessible as class attributes.
         """
    @@ -59,7 +60,12 @@ def __repr__(self):
     
     class LinearModel(object):
     
    -    """A linear model that has a vector of coefficients and an intercept."""
    +    """
    +    A linear model that has a vector of coefficients and an intercept.
    +
    +    :param weights: Weights computed for every feature.
    +    :param intercept: Intercept computed for this model.
    +    """
     
         def __init__(self, weights, intercept):
             self._coeff = _convert_to_vector(weights)
    @@ -128,10 +134,11 @@ class LinearRegressionModel(LinearRegressionModelBase):
         True
         >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
         True
    +    >>> from shutil import rmtree
         >>> try:
    -    ...    os.removedirs(path)
    +    ...     rmtree(path)
         ... except:
    -    ...    pass
    +    ...     pass
         >>> data = [
         ...     LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
         ...     LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
    @@ -193,18 +200,28 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
                   initialWeights=None, regParam=0.0, regType=None, intercept=False,
                   validateData=True):
             """
    -        Train a linear regression model on the given data.
    -
    -        :param data:              The training data.
    -        :param iterations:        The number of iterations (default: 100).
    +        Train a linear regression model using Stochastic Gradient
    +        Descent (SGD).
    +        This solves the least squares regression formulation
    +                f(weights) = 1/n ||A weights-y||^2^
    +        (which is the mean squared error).
    +        Here the data matrix has n rows, and the input RDD holds the
    +        set of rows of A, each with its corresponding right hand side
    +        label y. See also the documentation for the precise formulation.
    +
    +        :param data:              The training data, an RDD of
    +                                  LabeledPoint.
    +        :param iterations:        The number of iterations
    +                                  (default: 100).
             :param step:              The step parameter used in SGD
                                       (default: 1.0).
    -        :param miniBatchFraction: Fraction of data to be used for each SGD
    -                                  iteration.
    +        :param miniBatchFraction: Fraction of data to be used for each
    +                                  SGD iteration (default: 1.0).
             :param initialWeights:    The initial weights (default: None).
    -        :param regParam:          The regularizer parameter (default: 0.0).
    -        :param regType:           The type of regularizer used for training
    -                                  our model.
    +        :param regParam:          The regularizer parameter
    +                                  (default: 0.0).
    +        :param regType:           The type of regularizer used for
    +                                  training our model.
     
                                       :Allowed values:
                                          - "l1" for using L1 regularization (lasso),
    @@ -213,13 +230,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
     
                                          (default: None)
     
    -        :param intercept:         Boolean parameter which indicates the use
    -                                  or not of the augmented representation for
    -                                  training data (i.e. whether bias features
    -                                  are activated or not). (default: False)
    -        :param validateData:      Boolean parameter which indicates if the
    -                                  algorithm should validate data before training.
    -                                  (default: True)
    +        :param intercept:         Boolean parameter which indicates whether
    +                                  to use the augmented representation for
    +                                  training data, i.e. whether bias features
    +                                  are activated (default: False).
    +        :param validateData:      Boolean parameter which indicates whether
    +                                  the algorithm should validate data
    +                                  before training (default: True).
             """
             def train(rdd, i):
                 return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
    @@ -232,8 +250,8 @@ def train(rdd, i):
     @inherit_doc
     class LassoModel(LinearRegressionModelBase):
     
    -    """A linear regression model derived from a least-squares fit with an
    -    l_1 penalty term.
    +    """A linear regression model derived from a least-squares fit with
    +    an l_1 penalty term.
     
         >>> from pyspark.mllib.regression import LabeledPoint
         >>> data = [
    @@ -259,8 +277,9 @@ class LassoModel(LinearRegressionModelBase):
         True
         >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
         True
    +    >>> from shutil import rmtree
         >>> try:
    -    ...    os.removedirs(path)
    +    ...    rmtree(path)
         ... except:
         ...    pass
         >>> data = [
    @@ -304,7 +323,36 @@ class LassoWithSGD(object):
         def train(cls, data, iterations=100, step=1.0, regParam=0.01,
                   miniBatchFraction=1.0, initialWeights=None, intercept=False,
                   validateData=True):
    -        """Train a Lasso regression model on the given data."""
    +        """
    +        Train a regression model with L1-regularization using
    +        Stochastic Gradient Descent.
    +        This solves the l1-regularized least squares regression
    +        formulation
    +            f(weights) = 1/2n ||A weights-y||^2^  + regParam ||weights||_1
    +        Here the data matrix has n rows, and the input RDD holds the
    +        set of rows of A, each with its corresponding right hand side
    +        label y. See also the documentation for the precise formulation.
    +
    +        :param data:              The training data, an RDD of
    +                                  LabeledPoint.
    +        :param iterations:        The number of iterations
    +                                  (default: 100).
    +        :param step:              The step parameter used in SGD
    +                                  (default: 1.0).
    +        :param regParam:          The regularizer parameter
    +                                  (default: 0.01).
    +        :param miniBatchFraction: Fraction of data to be used for each
    +                                  SGD iteration (default: 1.0).
    +        :param initialWeights:    The initial weights (default: None).
    +        :param intercept:         Boolean parameter which indicates whether
    +                                  to use the augmented representation for
    +                                  training data, i.e. whether bias features
    +                                  are activated (default: False).
    +        :param validateData:      Boolean parameter which indicates whether
    +                                  the algorithm should validate data
    +                                  before training (default: True).
    +        """
             def train(rdd, i):
                 return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
                                      float(regParam), float(miniBatchFraction), i, bool(intercept),
    @@ -316,8 +364,8 @@ def train(rdd, i):
     @inherit_doc
     class RidgeRegressionModel(LinearRegressionModelBase):
     
    -    """A linear regression model derived from a least-squares fit with an
    -    l_2 penalty term.
    +    """A linear regression model derived from a least-squares fit with
    +    an l_2 penalty term.
     
         >>> from pyspark.mllib.regression import LabeledPoint
         >>> data = [
    @@ -344,8 +392,9 @@ class RidgeRegressionModel(LinearRegressionModelBase):
         True
         >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
         True
    +    >>> from shutil import rmtree
         >>> try:
    -    ...    os.removedirs(path)
    +    ...    rmtree(path)
         ... except:
         ...    pass
         >>> data = [
    @@ -389,7 +438,36 @@ class RidgeRegressionWithSGD(object):
         def train(cls, data, iterations=100, step=1.0, regParam=0.01,
                   miniBatchFraction=1.0, initialWeights=None, intercept=False,
                   validateData=True):
    -        """Train a ridge regression model on the given data."""
    +        """
    +        Train a regression model with L2-regularization using
    +        Stochastic Gradient Descent.
    +        This solves the l2-regularized least squares regression
    +        formulation
    +            f(weights) = 1/2n ||A weights-y||^2^  + regParam/2 ||weights||^2^
    +        Here the data matrix has n rows, and the input RDD holds the
    +        set of rows of A, each with its corresponding right hand side
    +        label y. See also the documentation for the precise formulation.
    +
    +        :param data:              The training data, an RDD of
    +                                  LabeledPoint.
    +        :param iterations:        The number of iterations
    +                                  (default: 100).
    +        :param step:              The step parameter used in SGD
    +                                  (default: 1.0).
    +        :param regParam:          The regularizer parameter
    +                                  (default: 0.01).
    +        :param miniBatchFraction: Fraction of data to be used for each
    +                                  SGD iteration (default: 1.0).
    +        :param initialWeights:    The initial weights (default: None).
    +        :param intercept:         Boolean parameter which indicates whether
    +                                  to use the augmented representation for
    +                                  training data, i.e. whether bias features
    +                                  are activated (default: False).
    +        :param validateData:      Boolean parameter which indicates whether
    +                                  the algorithm should validate data
    +                                  before training (default: True).
    +        """
             def train(rdd, i):
                 return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
                                      float(regParam), float(miniBatchFraction), i, bool(intercept),
    @@ -400,7 +478,15 @@ def train(rdd, i):
     
     class IsotonicRegressionModel(Saveable, Loader):
     
    -    """Regression model for isotonic regression.
    +    """
    +    Regression model for isotonic regression.
    +
    +    :param boundaries: Array of boundaries for which predictions are
    +            known. Boundaries must be sorted in increasing order.
    +    :param predictions: Array of predictions associated with the
    +            boundaries at the same index. These are the results of
    +            isotonic regression and are therefore monotone.
    +    :param isotonic: Indicates whether this is isotonic or antitonic.
     
         >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)]
         >>> irm = IsotonicRegression.train(sc.parallelize(data))
    @@ -418,8 +504,9 @@ class IsotonicRegressionModel(Saveable, Loader):
         2.0
         >>> sameModel.predict(5)
         16.5
    +    >>> from shutil import rmtree
         >>> try:
    -    ...     os.removedirs(path)
    +    ...     rmtree(path)
         ... except OSError:
         ...     pass
         """
    @@ -430,6 +517,25 @@ def __init__(self, boundaries, predictions, isotonic):
             self.isotonic = isotonic
     
         def predict(self, x):
    +        """
    +        Predict labels for provided features using a piecewise linear
    +        function:
    +
    +        1) If x exactly matches a boundary then the associated
    +        prediction is returned. If multiple predictions share that
    +        boundary, one of them is returned; which one is undefined
    +        (same as java.util.Arrays.binarySearch).
    +        2) If x is lower (higher) than all boundaries then the first
    +        (last) prediction is returned. If several equal boundaries sit
    +        at that end, the lowest (highest) of their predictions is
    +        returned.
    +        3) If x falls between two boundaries then the prediction is
    +        obtained by piecewise linear interpolation. Repeated boundaries
    +        are handled with the same rules as in 2).
    +
    +        :param x: Feature or RDD of Features to be labeled.
    +        """
             if isinstance(x, RDD):
                 return x.map(lambda v: self.predict(v))
             return np.interp(x, self.boundaries, self.predictions)
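The three rules in the docstring closely mirror the semantics of `numpy.interp` over the fitted (boundary, prediction) pairs, which is why `predict()` can delegate to it: clamping below the first and above the last boundary, and linear interpolation in between. A NumPy-only illustration:

```
import numpy as np

boundaries = np.array([0.0, 1.0, 2.0, 3.0])
predictions = np.array([1.0, 2.0, 2.0, 4.0])  # monotone, as isotonic regression guarantees

print(np.interp(-1.0, boundaries, predictions))  # 1.0 -> below all boundaries: first prediction
print(np.interp(1.5, boundaries, predictions))   # 2.0 -> between boundaries: linear interpolation
print(np.interp(2.5, boundaries, predictions))   # 3.0 -> halfway between 2.0 and 4.0
print(np.interp(9.0, boundaries, predictions))   # 4.0 -> above all boundaries: last prediction
```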
    @@ -451,20 +557,109 @@ def load(cls, sc, path):
     
     
     class IsotonicRegression(object):
    -    """
    -    Run IsotonicRegression algorithm to obtain isotonic regression model.
     
    -    :param data:            RDD of (label, feature, weight) tuples.
    -    :param isotonic:        Whether this is isotonic or antitonic.
    -    """
         @classmethod
         def train(cls, data, isotonic=True):
    -        """Train a isotonic regression model on the given data."""
    +        """
    +        Train an isotonic regression model on the given data.
    +
    +        :param data: RDD of (label, feature, weight) tuples.
    +        :param isotonic: Whether this is isotonic or antitonic.
    +        """
             boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel",
                                                     data.map(_convert_to_vector), bool(isotonic))
             return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)
     
     
    +class StreamingLinearAlgorithm(object):
    +    """
    +    Base class for streaming linear algorithms.
    +
    +    Provides the shared predictOn and predictOnValues implementations so
    +    that subclasses do not have to reimplement them.
    +    """
    +    def __init__(self, model):
    +        self._model = model
    +
    +    def latestModel(self):
    +        """
    +        Returns the latest model.
    +        """
    +        return self._model
    +
    +    def _validate(self, dstream):
    +        if not isinstance(dstream, DStream):
    +            raise TypeError(
    +                "dstream should be a DStream object, got %s" % type(dstream))
    +        if not self._model:
    +            raise ValueError(
    +                "Model must be intialized using setInitialWeights")
    +
    +    def predictOn(self, dstream):
    +        """
    +        Make predictions on a dstream.
    +
    +        :return: Transformed dstream object.
    +        """
    +        self._validate(dstream)
    +        return dstream.map(lambda x: self._model.predict(x))
    +
    +    def predictOnValues(self, dstream):
    +        """
    +        Make predictions on a keyed dstream.
    +
    +        :return: Transformed dstream object.
    +        """
    +        self._validate(dstream)
    +        return dstream.mapValues(lambda x: self._model.predict(x))
    +
    +
    +@inherit_doc
    +class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm):
    +    """
    +    Run LinearRegression with SGD on a batch of data.
    +
    +    The problem minimized is (1 / n_samples) * (y - weights'X)**2.
    +    After training on a batch of data, the weights obtained at the end of
    +    training are used as initial weights for the next batch.
    +
    +    :param stepSize: Step size for each iteration of gradient descent.
    +    :param numIterations: Total number of iterations run.
    +    :param miniBatchFraction: Fraction of data on which SGD is run for
    +                              each iteration.
    +    """
    +    def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0):
    +        self.stepSize = stepSize
    +        self.numIterations = numIterations
    +        self.miniBatchFraction = miniBatchFraction
    +        self._model = None
    +        super(StreamingLinearRegressionWithSGD, self).__init__(
    +            model=self._model)
    +
    +    def setInitialWeights(self, initialWeights):
    +        """
    +        Set the initial value of weights.
    +
    +        This must be set before running trainOn and predictOn.
    +        """
    +        initialWeights = _convert_to_vector(initialWeights)
    +        self._model = LinearRegressionModel(initialWeights, 0)
    +        return self
    +
    +    def trainOn(self, dstream):
    +        """Train the model on the incoming dstream."""
    +        self._validate(dstream)
    +
    +        def update(rdd):
    +            # LinearRegressionWithSGD.train raises an error for an empty RDD.
    +            if not rdd.isEmpty():
    +                self._model = LinearRegressionWithSGD.train(
    +                    rdd, self.numIterations, self.stepSize,
    +                    self.miniBatchFraction, self._model.weights,
    +                    self._model.intercept)
    +
    +        dstream.foreachRDD(update)
    +
    +
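Putting the new streaming classes together, a usage sketch (not part of the patch; it assumes an existing `SparkContext` named `sc` and uses `queueStream` purely for illustration):

```
from pyspark.streaming import StreamingContext
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD

ssc = StreamingContext(sc, 1)

# Two small training batches and one batch of (label, features) pairs to score.
train_batches = [
    sc.parallelize([LabeledPoint(1.0, [1.0]), LabeledPoint(2.0, [2.0])]),
    sc.parallelize([LabeledPoint(3.0, [3.0]), LabeledPoint(4.0, [4.0])]),
]
test_batches = [sc.parallelize([(5.0, [5.0])])]

model = StreamingLinearRegressionWithSGD(stepSize=0.1, numIterations=25)
model.setInitialWeights([0.0])        # required before trainOn/predictOnValues

model.trainOn(ssc.queueStream(train_batches))
model.predictOnValues(ssc.queueStream(test_batches)).pprint()

ssc.start()
ssc.awaitTermination(10)              # run briefly, then inspect the latest model
print(model.latestModel().weights)
ssc.stop(stopSparkContext=False)
```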
     def _test():
         import doctest
         from pyspark import SparkContext
    diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py
    new file mode 100644
    index 0000000000000..7da921976d4d2
    --- /dev/null
    +++ b/python/pyspark/mllib/stat/KernelDensity.py
    @@ -0,0 +1,61 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +import sys
    +
    +if sys.version > '3':
    +    xrange = range
    +
    +import numpy as np
    +
    +from pyspark.mllib.common import callMLlibFunc
    +from pyspark.rdd import RDD
    +
    +
    +class KernelDensity(object):
    +    """
    +    .. note:: Experimental
    +
    +    Estimate probability density at required points given an RDD of samples
    +    from the population.
    +
    +    >>> kd = KernelDensity()
    +    >>> sample = sc.parallelize([0.0, 1.0])
    +    >>> kd.setSample(sample)
    +    >>> kd.estimate([0.0, 1.0])
    +    array([ 0.12938758,  0.12938758])
    +    """
    +    def __init__(self):
    +        self._bandwidth = 1.0
    +        self._sample = None
    +
    +    def setBandwidth(self, bandwidth):
    +        """Set bandwidth of each sample. Defaults to 1.0"""
    +        self._bandwidth = bandwidth
    +
    +    def setSample(self, sample):
    +        """Set sample points from the population. Should be a RDD"""
    +        if not isinstance(sample, RDD):
    +            raise TypeError("samples should be a RDD, received %s" % type(sample))
    +        self._sample = sample
    +
    +    def estimate(self, points):
    +        """Estimate the probability density at points"""
    +        points = list(points)
    +        densities = callMLlibFunc(
    +            "estimateKernelDensity", self._sample, self._bandwidth, points)
    +        return np.asarray(densities)
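A short usage sketch for the new class (assumes the patch is applied and an existing `SparkContext` named `sc`):

```
from pyspark.mllib.stat import KernelDensity

kd = KernelDensity()
kd.setSample(sc.parallelize([0.0, 1.0, 2.0, 5.0]))
kd.setBandwidth(3.0)                       # kernel bandwidth; defaults to 1.0

# Evaluate the estimated density at a few points of interest.
print(kd.estimate([-1.0, 2.0, 5.0]))       # a numpy array of three densities
```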
    diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py
    index e3e128513e0d7..c8a721d3fe41c 100644
    --- a/python/pyspark/mllib/stat/__init__.py
    +++ b/python/pyspark/mllib/stat/__init__.py
    @@ -22,6 +22,7 @@
     from pyspark.mllib.stat._statistics import *
     from pyspark.mllib.stat.distribution import MultivariateGaussian
     from pyspark.mllib.stat.test import ChiSqTestResult
    +from pyspark.mllib.stat.KernelDensity import KernelDensity
     
     __all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult",
    -           "MultivariateGaussian"]
    +           "MultivariateGaussian", "KernelDensity"]
    diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
    index 36a4c7a5408c6..f2eab5b18f077 100644
    --- a/python/pyspark/mllib/tests.py
    +++ b/python/pyspark/mllib/tests.py
    @@ -23,8 +23,13 @@
     import sys
     import tempfile
     import array as pyarray
    +from time import time, sleep
    +from shutil import rmtree
    +
    +from numpy import (
    +    array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
    +from numpy import sum as array_sum
     
    -from numpy import array, array_equal, zeros, inf
     from py4j.protocol import Py4JJavaError
     
     if sys.version_info[:2] <= (2, 6):
    @@ -38,16 +43,22 @@
     
     from pyspark import SparkContext
     from pyspark.mllib.common import _to_java_object_rdd
    +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
     from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
    -    DenseMatrix, SparseMatrix, Vectors, Matrices
    -from pyspark.mllib.regression import LabeledPoint
    +    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
    +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
    +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
     from pyspark.mllib.random import RandomRDDs
     from pyspark.mllib.stat import Statistics
     from pyspark.mllib.feature import Word2Vec
     from pyspark.mllib.feature import IDF
    -from pyspark.mllib.feature import StandardScaler
    +from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
    +from pyspark.mllib.util import LinearDataGenerator
    +from pyspark.mllib.util import MLUtils
     from pyspark.serializers import PickleSerializer
    +from pyspark.streaming import StreamingContext
     from pyspark.sql import SQLContext
     
     _have_scipy = False
     try:
    @@ -66,6 +77,20 @@ def setUp(self):
             self.sc = sc
     
     
    +class MLLibStreamingTestCase(unittest.TestCase):
    +    def setUp(self):
    +        self.sc = sc
    +        self.ssc = StreamingContext(self.sc, 1.0)
    +
    +    def tearDown(self):
    +        self.ssc.stop(False)
    +
    +    @staticmethod
    +    def _ssc_wait(start_time, end_time, sleep_time):
    +        while time() - start_time < end_time:
    +            sleep(sleep_time)
    +
    +
     def _squared_distance(a, b):
         if isinstance(a, Vector):
             return a.squared_distance(b)
    @@ -104,17 +129,22 @@ def test_dot(self):
                          [1., 2., 3., 4.],
                          [1., 2., 3., 4.],
                          [1., 2., 3., 4.]])
    +        arr = pyarray.array('d', [0, 1, 2, 3])
             self.assertEquals(10.0, sv.dot(dv))
             self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
             self.assertEquals(30.0, dv.dot(dv))
             self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
             self.assertEquals(30.0, lst.dot(dv))
             self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
    +        self.assertEquals(7.0, sv.dot(arr))
     
         def test_squared_distance(self):
             sv = SparseVector(4, {1: 1, 3: 2})
             dv = DenseVector(array([1., 2., 3., 4.]))
             lst = DenseVector([4, 3, 2, 1])
    +        lst1 = [4, 3, 2, 1]
    +        arr = pyarray.array('d', [0, 2, 1, 3])
    +        narr = array([0, 2, 1, 3])
             self.assertEquals(15.0, _squared_distance(sv, dv))
             self.assertEquals(25.0, _squared_distance(sv, lst))
             self.assertEquals(20.0, _squared_distance(dv, lst))
    @@ -124,6 +154,9 @@ def test_squared_distance(self):
             self.assertEquals(0.0, _squared_distance(sv, sv))
             self.assertEquals(0.0, _squared_distance(dv, dv))
             self.assertEquals(0.0, _squared_distance(lst, lst))
    +        self.assertEquals(25.0, _squared_distance(sv, lst1))
    +        self.assertEquals(3.0, _squared_distance(sv, arr))
    +        self.assertEquals(3.0, _squared_distance(sv, narr))
     
         def test_conversion(self):
             # numpy arrays should be automatically upcast to float64
    @@ -156,6 +189,53 @@ def test_matrix_indexing(self):
                 for j in range(2):
                     self.assertEquals(mat[i, j], expected[i][j])
     
    +    def test_repr_dense_matrix(self):
    +        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
    +        self.assertTrue(
    +            repr(mat),
    +            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
    +
    +        mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
    +        self.assertTrue(
    +            repr(mat),
    +            'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
    +
    +        mat = DenseMatrix(6, 3, zeros(18))
    +        self.assertTrue(
    +            repr(mat),
    +            'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
    +                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
    +
    +    def test_repr_sparse_matrix(self):
    +        sm1t = SparseMatrix(
    +            3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
    +            isTransposed=True)
    +        self.assertTrue(
    +            repr(sm1t),
    +            'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
    +
    +        indices = tile(arange(6), 3)
    +        values = ones(18)
    +        sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
    +        self.assertTrue(
    +            repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
    +                [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
    +                [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
    +                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
    +
    +        self.assertTrue(
    +            str(sm),
    +            "6 X 3 CSCMatrix\n\
    +            (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
    +            (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
    +            (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
    +
    +        sm = SparseMatrix(1, 18, zeros(19), [], [])
    +        self.assertTrue(
    +            repr(sm),
    +            'SparseMatrix(1, 18, \
    +                [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
    +
         def test_sparse_matrix(self):
             # Test sparse matrix creation.
             sm1 = SparseMatrix(
    @@ -165,6 +245,9 @@ def test_sparse_matrix(self):
             self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
             self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2])
             self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
    +        self.assertTrue(
    +            repr(sm1),
    +            'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
     
             # Test indexing
             expected = [
    @@ -379,7 +462,7 @@ def test_classification(self):
             self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
     
             try:
    -            os.removedirs(temp_dir)
    +            rmtree(temp_dir)
             except OSError:
                 pass
     
    @@ -443,6 +526,13 @@ def test_regression(self):
             except ValueError:
                 self.fail()
     
    +        # Verify that maxBins is being passed through
    +        GradientBoostedTrees.trainRegressor(
    +            rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32)
    +        with self.assertRaises(Exception) as cm:
    +            GradientBoostedTrees.trainRegressor(
    +                rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1)
    +
     
     class StatTests(MLlibTestCase):
         # SPARK-4023
    @@ -507,6 +597,38 @@ def test_infer_schema(self):
                     raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
     
     
    +class MatrixUDTTests(MLlibTestCase):
    +
    +    dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
    +    dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
    +    sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
    +    sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
    +    udt = MatrixUDT()
    +
    +    def test_json_schema(self):
    +        self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
    +
    +    def test_serialization(self):
    +        for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
    +            self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
    +
    +    def test_infer_schema(self):
    +        sqlCtx = SQLContext(self.sc)
    +        rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
    +        df = rdd.toDF()
    +        schema = df.schema
    +        self.assertTrue(schema.fields[1].dataType, self.udt)
    +        matrices = df.map(lambda x: x._2).collect()
    +        self.assertEqual(len(matrices), 2)
    +        for m in matrices:
    +            if isinstance(m, DenseMatrix):
    +                self.assertTrue(m, self.dm1)
    +            elif isinstance(m, SparseMatrix):
    +                self.assertTrue(m, self.sm1)
    +            else:
    +                raise ValueError("Expected a matrix but got type %r" % type(m))
    +
    +
     @unittest.skipIf(not _have_scipy, "SciPy not installed")
     class SciPyTests(MLlibTestCase):
     
    @@ -818,6 +940,457 @@ def test_model_transform(self):
             self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))
     
     
    +class ElementwiseProductTests(MLlibTestCase):
    +    def test_model_transform(self):
    +        weight = Vectors.dense([3, 2, 1])
    +
    +        densevec = Vectors.dense([4, 5, 6])
    +        sparsevec = Vectors.sparse(3, [0], [1])
    +        eprod = ElementwiseProduct(weight)
    +        self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6]))
    +        self.assertEqual(
    +            eprod.transform(sparsevec), SparseVector(3, [0], [3]))
    +
    +
    +class StreamingKMeansTest(MLLibStreamingTestCase):
    +    def test_model_params(self):
    +        """Test that the model params are set correctly"""
    +        stkm = StreamingKMeans()
    +        stkm.setK(5).setDecayFactor(0.0)
    +        self.assertEquals(stkm._k, 5)
    +        self.assertEquals(stkm._decayFactor, 0.0)
    +
    +        # Model not set yet.
    +        self.assertIsNone(stkm.latestModel())
    +        self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0])
    +
    +        stkm.setInitialCenters(
    +            centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0])
    +        self.assertEquals(
    +            stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]])
    +        self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0])
    +
    +    def test_accuracy_for_single_center(self):
    +        """Test that parameters obtained are correct for a single center."""
    +        centers, batches = self.streamingKMeansDataGenerator(
    +            batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0)
    +        stkm = StreamingKMeans(1)
    +        stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.])
    +        input_stream = self.ssc.queueStream(
    +            [self.sc.parallelize(batch, 1) for batch in batches])
    +        stkm.trainOn(input_stream)
    +
    +        t = time()
    +        self.ssc.start()
    +        self._ssc_wait(t, 10.0, 0.01)
    +        self.assertEquals(stkm.latestModel().clusterWeights, [25.0])
    +        realCenters = array_sum(array(centers), axis=0)
    +        for i in range(5):
    +            modelCenters = stkm.latestModel().centers[0][i]
    +            self.assertAlmostEqual(centers[0][i], modelCenters, 1)
    +            self.assertAlmostEqual(realCenters[i], modelCenters, 1)
    +
    +    def streamingKMeansDataGenerator(self, batches, numPoints,
    +                                     k, d, r, seed, centers=None):
    +        rng = random.RandomState(seed)
    +
    +        # Generate centers.
    +        centers = [rng.randn(d) for i in range(k)]
    +
    +        return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d))
    +                          for j in range(numPoints)]
    +                         for i in range(batches)]
    +
    +    def test_trainOn_model(self):
    +        """Test the model on toy data with four clusters."""
    +        stkm = StreamingKMeans()
    +        initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]
    +        stkm.setInitialCenters(
    +            centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0])
    +
    +        # Create a toy dataset by setting a tiny offset for each point.
    +        offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]]
    +        batches = []
    +        for offset in offsets:
    +            batches.append([[offset[0] + center[0], offset[1] + center[1]]
    +                            for center in initCenters])
    +
    +        batches = [self.sc.parallelize(batch, 1) for batch in batches]
    +        input_stream = self.ssc.queueStream(batches)
    +        stkm.trainOn(input_stream)
    +        t = time()
    +        self.ssc.start()
    +
    +        # Give enough time to train the model.
    +        self._ssc_wait(t, 6.0, 0.01)
    +        finalModel = stkm.latestModel()
    +        self.assertTrue(all(finalModel.centers == array(initCenters)))
    +        self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
    +
    +    def test_predictOn_model(self):
    +        """Test that the model predicts correctly on toy data."""
    +        stkm = StreamingKMeans()
    +        stkm._model = StreamingKMeansModel(
    +            clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]],
    +            clusterWeights=[1.0, 1.0, 1.0, 1.0])
    +
    +        predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]]
    +        predict_data = [sc.parallelize(batch, 1) for batch in predict_data]
    +        predict_stream = self.ssc.queueStream(predict_data)
    +        predict_val = stkm.predictOn(predict_stream)
    +
    +        result = []
    +
    +        def update(rdd):
    +            rdd_collect = rdd.collect()
    +            if rdd_collect:
    +                result.append(rdd_collect)
    +
    +        predict_val.foreachRDD(update)
    +        t = time()
    +        self.ssc.start()
    +        self._ssc_wait(t, 6.0, 0.01)
    +        self.assertEquals(result, [[0], [1], [2], [3]])
    +
    +    def test_trainOn_predictOn(self):
    +        """Test that prediction happens on the updated model."""
    +        stkm = StreamingKMeans(decayFactor=0.0, k=2)
    +        stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0])
    +
    +        # Since the decay factor is set to zero, once the first batch
    +        # is processed the clusterCenters are updated to [-0.5, 0.7].
    +        # This causes 0.2 and 0.3 to be classified as 1, even though a
    +        # classification based on the initial model would have been 0,
    +        # proving that the model is updated.
    +        batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]]
    +        batches = [sc.parallelize(batch) for batch in batches]
    +        input_stream = self.ssc.queueStream(batches)
    +        predict_results = []
    +
    +        def collect(rdd):
    +            rdd_collect = rdd.collect()
    +            if rdd_collect:
    +                predict_results.append(rdd_collect)
    +
    +        stkm.trainOn(input_stream)
    +        predict_stream = stkm.predictOn(input_stream)
    +        predict_stream.foreachRDD(collect)
    +
    +        t = time()
    +        self.ssc.start()
    +        self._ssc_wait(t, 6.0, 0.01)
    +        self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
    +
    +
    +class LinearDataGeneratorTests(MLlibTestCase):
    +    def test_dim(self):
    +        linear_data = LinearDataGenerator.generateLinearInput(
    +            intercept=0.0, weights=[0.0, 0.0, 0.0],
    +            xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
    +            nPoints=4, seed=0, eps=0.1)
    +        self.assertEqual(len(linear_data), 4)
    +        for point in linear_data:
    +            self.assertEqual(len(point.features), 3)
    +
    +        linear_data = LinearDataGenerator.generateLinearRDD(
    +            sc=sc, nexamples=6, nfeatures=2, eps=0.1,
    +            nParts=2, intercept=0.0).collect()
    +        self.assertEqual(len(linear_data), 6)
    +        for point in linear_data:
    +            self.assertEqual(len(point.features), 2)
    +
    +
    +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
    +
    +    @staticmethod
    +    def generateLogisticInput(offset, scale, nPoints, seed):
    +        """
    +        Generate 1 / (1 + exp(-x * scale + offset))
    +
    +        where,
    +        x is randomnly distributed and the threshold
    +        and labels for each sample in x is obtained from a random uniform
    +        distribution.
    +        """
    +        rng = random.RandomState(seed)
    +        x = rng.randn(nPoints)
    +        sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset)))
    +        y_p = rng.rand(nPoints)
    +        cut_off = y_p <= sigmoid
    +        y_p[cut_off] = 1.0
    +        y_p[~cut_off] = 0.0
    +        return [
    +            LabeledPoint(y_p[i], Vectors.dense([x[i]]))
    +            for i in range(nPoints)]
    +
    +    def test_parameter_accuracy(self):
    +        """
    +        Test that the final value of weights is close to the desired value.
    +        """
    +        input_batches = [
    +            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
    +            for i in range(20)]
    +        input_stream = self.ssc.queueStream(input_batches)
    +
    +        slr = StreamingLogisticRegressionWithSGD(
    +            stepSize=0.2, numIterations=25)
    +        slr.setInitialWeights([0.0])
    +        slr.trainOn(input_stream)
    +
    +        t = time()
    +        self.ssc.start()
    +        self._ssc_wait(t, 20.0, 0.01)
    +        rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5
    +        self.assertAlmostEqual(rel, 0.1, 1)
    +
    +    def test_convergence(self):
    +        """
    +        Test that weights converge to the required value on toy data.
    +        """
    +        input_batches = [
    +            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
    +            for i in range(20)]
    +        input_stream = self.ssc.queueStream(input_batches)
    +        models = []
    +
    +        slr = StreamingLogisticRegressionWithSGD(
    +            stepSize=0.2, numIterations=25)
    +        slr.setInitialWeights([0.0])
    +        slr.trainOn(input_stream)
    +        input_stream.foreachRDD(
    +            lambda x: models.append(slr.latestModel().weights[0]))
    +
    +        t = time()
    +        self.ssc.start()
    +        self._ssc_wait(t, 15.0, 0.01)
    +        t_models = array(models)
    +        diff = t_models[1:] - t_models[:-1]
    +
    +        # Test that the weights improve, allowing a small tolerance for noise.
    +        self.assertTrue(all(diff >= -0.1))
    +        self.assertTrue(array_sum(diff > 0) > 1)
    +
    +    @staticmethod
    +    def calculate_accuracy_error(true, predicted):
    +        return sum(abs(array(true) - array(predicted))) / len(true)
    +
    +    def test_predictions(self):
    +        """Test predicted values on a toy model."""
    +        input_batches = []
    +        for i in range(20):
    +            batch = self.sc.parallelize(
    +                self.generateLogisticInput(0, 1.5, 100, 42 + i))
    +            input_batches.append(batch.map(lambda x: (x.label, x.features)))
    +        input_stream = self.ssc.queueStream(input_batches)
    +
    +        slr = StreamingLogisticRegressionWithSGD(
    +            stepSize=0.2, numIterations=25)
    +        slr.setInitialWeights([1.5])
    +        predict_stream = slr.predictOnValues(input_stream)
    +        true_predicted = []
    +        predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect()))
    +        t = time()
    +        self.ssc.start()
    +        self._ssc_wait(t, 5.0, 0.01)
    +
    +        # Test that the accuracy error is no more than 0.4 on each batch.
    +        for batch in true_predicted:
    +            true, predicted = zip(*batch)
    +            self.assertTrue(
    +                self.calculate_accuracy_error(true, predicted) < 0.4)
    +
    +    def test_training_and_prediction(self):
    +        """Test that the model improves on toy data with no. of batches"""
    +        input_batches = [
    +            self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
    +            for i in range(20)]
    +        predict_batches = [
    +            b.map(lambda lp: (lp.label, lp.features)) for b in input_batches]
    +
    +        slr = StreamingLogisticRegressionWithSGD(
    +            stepSize=0.01, numIterations=25)
    +        slr.setInitialWeights([-0.1])
    +        errors = []
    +
    +        def collect_errors(rdd):
    +            true, predicted = zip(*rdd.collect())
    +            errors.append(self.calculate_accuracy_error(true, predicted))
    +
    +        true_predicted = []
    +        input_stream = self.ssc.queueStream(input_batches)
    +        predict_stream = self.ssc.queueStream(predict_batches)
    +        slr.trainOn(input_stream)
    +        ps = slr.predictOnValues(predict_stream)
    +        ps.foreachRDD(lambda x: collect_errors(x))
    +
    +        t = time()
    +        self.ssc.start()
    +        self._ssc_wait(t, 20.0, 0.01)
    +
    +        # Test that the improvement in error is at least 0.3
    +        self.assertTrue(errors[1] - errors[-1] > 0.3)
    +
    +
    +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
    +
    +    def assertArrayAlmostEqual(self, array1, array2, dec):
    +        for i, j in zip(array1, array2):
    +            self.assertAlmostEqual(i, j, dec)
    +
    +    def test_parameter_accuracy(self):
    +        """Test that coefs are predicted accurately by fitting on toy data."""
    +
    +        # Test that fitting y = 10*X1 + 10*X2 on features (X1, X2)
    +        # recovers coefficients close to (10, 10).
    +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
    +        slr.setInitialWeights([0.0, 0.0])
    +        xMean = [0.0, 0.0]
    +        xVariance = [1.0 / 3.0, 1.0 / 3.0]
    +
    +        # Create ten batches with 100 sample points in each.
    +        batches = []
    +        for i in range(10):
    +            batch = LinearDataGenerator.generateLinearInput(
    +                0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1)
    +            batches.append(sc.parallelize(batch))
    +
    +        input_stream = self.ssc.queueStream(batches)
    +        t = time()
    +        slr.trainOn(input_stream)
    +        self.ssc.start()
    +        self._ssc_wait(t, 10, 0.01)
    +        self.assertArrayAlmostEqual(
    +            slr.latestModel().weights.array, [10., 10.], 1)
    +        self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1)
    +
    +    def test_parameter_convergence(self):
    +        """Test that the model parameters improve with streaming data."""
    +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
    +        slr.setInitialWeights([0.0])
    +
    +        # Create ten batches with 100 sample points in each.
    +        batches = []
    +        for i in range(10):
    +            batch = LinearDataGenerator.generateLinearInput(
    +                0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
    +            batches.append(sc.parallelize(batch))
    +
    +        model_weights = []
    +        input_stream = self.ssc.queueStream(batches)
    +        input_stream.foreachRDD(
    +            lambda x: model_weights.append(slr.latestModel().weights[0]))
    +        t = time()
    +        slr.trainOn(input_stream)
    +        self.ssc.start()
    +        self._ssc_wait(t, 10, 0.01)
    +
    +        model_weights = array(model_weights)
    +        diff = model_weights[1:] - model_weights[:-1]
    +        self.assertTrue(all(diff >= -0.1))
    +
    +    def test_prediction(self):
    +        """Test prediction on a model with weights already set."""
    +        # Create a model whose initial weights equal the true coefficients.
    +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
    +        slr.setInitialWeights([10.0, 10.0])
    +
    +        # Create ten batches with 100 sample points in each.
    +        batches = []
    +        for i in range(10):
    +            batch = LinearDataGenerator.generateLinearInput(
    +                0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0],
    +                100, 42 + i, 0.1)
    +            batches.append(
    +                sc.parallelize(batch).map(lambda lp: (lp.label, lp.features)))
    +
    +        input_stream = self.ssc.queueStream(batches)
    +        t = time()
    +        output_stream = slr.predictOnValues(input_stream)
    +        samples = []
    +        output_stream.foreachRDD(lambda x: samples.append(x.collect()))
    +
    +        self.ssc.start()
    +        self._ssc_wait(t, 5, 0.01)
    +
    +        # Test that mean absolute error on each batch is less than 0.1
    +        for batch in samples:
    +            true, predicted = zip(*batch)
    +            self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1)
    +
    +    def test_train_prediction(self):
    +        """Test that error on test data improves as model is trained."""
    +        slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
    +        slr.setInitialWeights([0.0])
    +
    +        # Create ten batches with 100 sample points in each.
    +        batches = []
    +        for i in range(10):
    +            batch = LinearDataGenerator.generateLinearInput(
    +                0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
    +            batches.append(sc.parallelize(batch))
    +
    +        predict_batches = [
    +            b.map(lambda lp: (lp.label, lp.features)) for b in batches]
    +        mean_absolute_errors = []
    +
    +        def func(rdd):
    +            true, predicted = zip(*rdd.collect())
    +            mean_absolute_errors.append(mean(abs(array(true) - array(predicted))))
    +
    +        model_weights = []
    +        input_stream = self.ssc.queueStream(batches)
    +        output_stream = self.ssc.queueStream(predict_batches)
    +        t = time()
    +        slr.trainOn(input_stream)
    +        output_stream = slr.predictOnValues(output_stream)
    +        output_stream.foreachRDD(func)
    +        self.ssc.start()
    +        self._ssc_wait(t, 10, 0.01)
    +        self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2)
    +
    +
    +class MLUtilsTests(MLlibTestCase):
    +    def test_append_bias(self):
    +        data = [2.0, 2.0, 2.0]
    +        ret = MLUtils.appendBias(data)
    +        self.assertEqual(ret[3], 1.0)
    +        self.assertEqual(type(ret), DenseVector)
    +
    +    def test_append_bias_with_vector(self):
    +        data = Vectors.dense([2.0, 2.0, 2.0])
    +        ret = MLUtils.appendBias(data)
    +        self.assertEqual(ret[3], 1.0)
    +        self.assertEqual(type(ret), DenseVector)
    +
    +    def test_append_bias_with_sp_vector(self):
    +        data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
    +        expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
    +        # Returned value must be SparseVector
    +        ret = MLUtils.appendBias(data)
    +        self.assertEqual(ret, expected)
    +        self.assertEqual(type(ret), SparseVector)
    +
    +    def test_load_vectors(self):
    +        import shutil
    +        data = [
    +            [1.0, 2.0, 3.0],
    +            [1.0, 2.0, 3.0]
    +        ]
    +        temp_dir = tempfile.mkdtemp()
    +        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
    +        try:
    +            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
    +            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
    +            ret = ret_rdd.collect()
    +            self.assertEqual(len(ret), 2)
    +            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
    +            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
    +        except Exception:
    +            self.fail()
    +        finally:
    +            shutil.rmtree(temp_dir)
    +
    +
     if __name__ == "__main__":
         if not _have_scipy:
             print("NOTE: Skipping SciPy tests as it does not seem to be installed")
    diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
    index cfcbea573fd22..372b86a7c95d9 100644
    --- a/python/pyspark/mllib/tree.py
    +++ b/python/pyspark/mllib/tree.py
    @@ -299,7 +299,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
                      1 internal node + 2 leaf nodes. (default: 4)
             :param maxBins: maximum number of bins used for splitting
                      features
    -                 (default: 100)
    +                 (default: 32)
             :param seed: Random seed for bootstrapping and choosing feature
                      subsets.
             :return: RandomForestModel that can be used for prediction
    @@ -377,7 +377,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt
                      1 leaf node; depth 1 means 1 internal node + 2 leaf
                      nodes. (default: 4)
             :param maxBins: maximum number of bins used for splitting
    -                 features (default: 100)
    +                 features (default: 32)
             :param seed: Random seed for bootstrapping and choosing feature
                      subsets.
             :return: RandomForestModel that can be used for prediction
    @@ -435,16 +435,17 @@ class GradientBoostedTrees(object):
     
         @classmethod
         def _train(cls, data, algo, categoricalFeaturesInfo,
    -               loss, numIterations, learningRate, maxDepth):
    +               loss, numIterations, learningRate, maxDepth, maxBins):
             first = data.first()
             assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
             model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo,
    -                              loss, numIterations, learningRate, maxDepth)
    +                              loss, numIterations, learningRate, maxDepth, maxBins)
             return GradientBoostedTreesModel(model)
     
         @classmethod
         def trainClassifier(cls, data, categoricalFeaturesInfo,
    -                        loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3):
    +                        loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3,
    +                        maxBins=32):
             """
             Method to train a gradient-boosted trees model for
             classification.
    @@ -467,6 +468,8 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
             :param maxDepth: Maximum depth of the tree. E.g., depth 0 means
                      1 leaf node; depth 1 means 1 internal node + 2 leaf
                      nodes. (default: 3)
    +        :param maxBins: maximum number of bins used for splitting
    +                 features (default: 32). DecisionTree requires maxBins >= max categories.
             :return: GradientBoostedTreesModel that can be used for
                        prediction
     
    @@ -499,11 +502,12 @@ def trainClassifier(cls, data, categoricalFeaturesInfo,
             [1.0, 0.0]
             """
             return cls._train(data, "classification", categoricalFeaturesInfo,
    -                          loss, numIterations, learningRate, maxDepth)
    +                          loss, numIterations, learningRate, maxDepth, maxBins)
     
         @classmethod
         def trainRegressor(cls, data, categoricalFeaturesInfo,
    -                       loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3):
    +                       loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3,
    +                       maxBins=32):
             """
             Method to train a gradient-boosted trees model for regression.
     
    @@ -522,6 +526,8 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
                      contribution of each estimator. The learning rate
                      should be between in the interval (0, 1].
                      (default: 0.1)
    +        :param maxBins: maximum number of bins used for splitting
    +                 features (default: 32). DecisionTree requires maxBins >= max categories.
             :param maxDepth: Maximum depth of the tree. E.g., depth 0 means
                      1 leaf node; depth 1 means 1 internal node + 2 leaf
                      nodes.  (default: 3)
    @@ -556,7 +562,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
             [1.0, 0.0]
             """
             return cls._train(data, "regression", categoricalFeaturesInfo,
    -                          loss, numIterations, learningRate, maxDepth)
    +                          loss, numIterations, learningRate, maxDepth, maxBins)
     
     
     def _test():
    diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
    index 16a90db146ef0..875d3b2d642c6 100644
    --- a/python/pyspark/mllib/util.py
    +++ b/python/pyspark/mllib/util.py
    @@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None):
             minPartitions = minPartitions or min(sc.defaultParallelism, 2)
             return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
     
    +    @staticmethod
    +    def appendBias(data):
    +        """
    +        Returns a new vector with `1.0` (bias) appended to
    +        the end of the input vector.
    +        """
    +        vec = _convert_to_vector(data)
    +        if isinstance(vec, SparseVector):
    +            newIndices = np.append(vec.indices, len(vec))
    +            newValues = np.append(vec.values, 1.0)
    +            return SparseVector(len(vec) + 1, newIndices, newValues)
    +        else:
    +            return _convert_to_vector(np.append(vec.toArray(), 1.0))
    +
    +    @staticmethod
    +    def loadVectors(sc, path):
    +        """
    +        Loads vectors saved using `RDD[Vector].saveAsTextFile`
    +        with the default number of partitions.
    +        """
    +        return callMLlibFunc("loadVectors", sc, path)
    +
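    A quick usage sketch of the two helpers added above, mirroring the MLUtilsTests added to tests.py earlier in this patch. `appendBias` needs no SparkContext; the `loadVectors` call and its path are illustrative and assume a running `sc`:

        from pyspark.mllib.linalg import Vectors
        from pyspark.mllib.util import MLUtils

        MLUtils.appendBias([2.0, 2.0, 2.0])
        # -> DenseVector([2.0, 2.0, 2.0, 1.0])
        MLUtils.appendBias(Vectors.sparse(3, {0: 2.0, 2: 2.0}))
        # -> SparseVector(4, {0: 2.0, 2: 2.0, 3: 1.0})

        # vectors = MLUtils.loadVectors(sc, "/tmp/saved_vectors")  # RDD of Vector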
     
     class Saveable(object):
         """
    @@ -257,6 +279,41 @@ def load(cls, sc, path):
             return cls(java_model)
     
     
    +class LinearDataGenerator(object):
    +    """Utils for generating linear data"""
    +
    +    @staticmethod
    +    def generateLinearInput(intercept, weights, xMean, xVariance,
    +                            nPoints, seed, eps):
    +        """
    +        :param intercept: bias factor, the term c in X'w + c
    +        :param weights: feature vector, the term w in X'w + c
    +        :param xMean: point around which the data X is centered
    +        :param xVariance: variance of the given data
    +        :param nPoints: number of points to be generated
    +        :param seed: random seed
    +        :param eps: scale of the Gaussian noise; the higher eps is,
    +                    the more noise is added
    +        :return: a list of LabeledPoints of length nPoints
    +        """
    +        weights = [float(weight) for weight in weights]
    +        xMean = [float(mean) for mean in xMean]
    +        xVariance = [float(var) for var in xVariance]
    +        return list(callMLlibFunc(
    +            "generateLinearInputWrapper", float(intercept), weights, xMean,
    +            xVariance, int(nPoints), int(seed), float(eps)))
    +
    +    @staticmethod
    +    def generateLinearRDD(sc, nexamples, nfeatures, eps,
    +                          nParts=2, intercept=0.0):
    +        """
    +        Generate an RDD of LabeledPoints.
    +        """
    +        return callMLlibFunc(
    +            "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures),
    +            float(eps), int(nParts), float(intercept))
    +
    +
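    How the generator is typically driven, mirroring the streaming regression tests added earlier in this patch. A running SparkContext `sc` is assumed:

        from pyspark.mllib.util import LinearDataGenerator

        # 100 LabeledPoints drawn around y = 10*x1 + 10*x2 with light noise
        points = LinearDataGenerator.generateLinearInput(
            0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], 100, 42, 0.1)
        batch = sc.parallelize(points)

        # Or build an RDD of LabeledPoints directly
        rdd = LinearDataGenerator.generateLinearRDD(sc, nexamples=1000, nfeatures=10,
                                                    eps=0.1, nParts=2, intercept=0.0)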
     def _test():
         import doctest
         from pyspark.context import SparkContext
    diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
    index d18daaabfcb3c..44d17bd629473 100644
    --- a/python/pyspark/profiler.py
    +++ b/python/pyspark/profiler.py
    @@ -90,9 +90,11 @@ class Profiler(object):
         >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
         >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10)
         [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
    +    >>> sc.parallelize(range(1000)).count()
    +    1000
         >>> sc.show_profiles()
         My custom profiles for RDD:1
    -    My custom profiles for RDD:2
    +    My custom profiles for RDD:3
         >>> sc.stop()
         """
     
    @@ -169,4 +171,6 @@ def stats(self):
     
     if __name__ == "__main__":
         import doctest
    -    doctest.testmod()
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
    diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
    index 545c5ad20cb96..3218bed5c74fc 100644
    --- a/python/pyspark/rdd.py
    +++ b/python/pyspark/rdd.py
    @@ -121,10 +121,23 @@ def _parse_memory(s):
     
     
     def _load_from_socket(port, serializer):
    -    sock = socket.socket()
    -    sock.settimeout(3)
    +    sock = None
    +    # Support for both IPv4 and IPv6.
    +    # On most of IPv6-ready systems, IPv6 will take precedence.
    +    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
    +        af, socktype, proto, canonname, sa = res
    +        sock = socket.socket(af, socktype, proto)
    +        try:
    +            sock.settimeout(3)
    +            sock.connect(sa)
    +        except socket.error:
    +            sock.close()
    +            sock = None
    +            continue
    +        break
    +    if not sock:
    +        raise Exception("could not open socket")
         try:
    -        sock.connect(("localhost", port))
             rf = sock.makefile("rb", 65536)
             for item in serializer.load_stream(rf):
                 yield item
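    The connection loop above is the usual dual-stack pattern: try every address family returned by getaddrinfo and keep the first socket that connects, so IPv6-only and IPv4-only setups both work. As a standalone sketch of the same idea (not part of the patch):

        import socket

        def connect_localhost(port, timeout=3):
            """Return a socket connected to localhost, trying IPv6 and IPv4."""
            last_error = None
            for af, socktype, proto, _, sa in socket.getaddrinfo(
                    "localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
                sock = socket.socket(af, socktype, proto)
                try:
                    sock.settimeout(timeout)
                    sock.connect(sa)
                    return sock
                except socket.error as e:
                    last_error = e
                    sock.close()
            raise Exception("could not open socket: %s" % last_error)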
    @@ -687,12 +700,14 @@ def groupBy(self, f, numPartitions=None):
             return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
     
         @ignore_unicode_prefix
    -    def pipe(self, command, env={}):
    +    def pipe(self, command, env={}, checkCode=False):
             """
             Return an RDD created by piping elements to a forked external process.
     
             >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
             [u'1', u'2', u'', u'3']
    +
    +        :param checkCode: whether to check the exit code of the shell command.
             """
             def func(iterator):
                 pipe = Popen(
    @@ -704,7 +719,17 @@ def pipe_objs(out):
                         out.write(s.encode('utf-8'))
                     out.close()
                 Thread(target=pipe_objs, args=[pipe.stdin]).start()
    -            return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
    +
    +            def check_return_code():
    +                pipe.wait()
    +                if checkCode and pipe.returncode:
    +                    raise Exception("Pipe function `%s' exited "
    +                                    "with error code %d" % (command, pipe.returncode))
    +                else:
    +                    for i in range(0):
    +                        yield i
    +            return (x.rstrip(b'\n').decode('utf-8') for x in
    +                    chain(iter(pipe.stdout.readline, b''), check_return_code()))
             return self.mapPartitions(func)
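    The `check_return_code` generator above yields nothing; chaining it after the process output means its body only runs once the real output is exhausted, which is after the subprocess has finished, so the exit-code check fires at the right time. A plain-Python sketch of the same trick (the names here are hypothetical, not part of the patch):

        from itertools import chain

        def lines_then_check(lines, get_exit_code):
            def sentinel():
                code = get_exit_code()          # runs only after `lines` is drained
                if code:
                    raise Exception("command exited with code %d" % code)
                return
                yield  # unreachable; makes this function a generator
            return chain(lines, sentinel())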
     
         def foreach(self, f):
    @@ -813,13 +838,21 @@ def op(x, y):
         def fold(self, zeroValue, op):
             """
             Aggregate the elements of each partition, and then the results for all
    -        the partitions, using a given associative function and a neutral "zero
    -        value."
    +        the partitions, using a given associative and commutative function and
    +        a neutral "zero value."
     
             The function C{op(t1, t2)} is allowed to modify C{t1} and return it
             as its result value to avoid object allocation; however, it should not
             modify C{t2}.
     
    +        This behaves somewhat differently from fold operations implemented
    +        for non-distributed collections in functional languages like Scala.
    +        The fold is first applied within each partition and the per-partition
    +        results are then folded into the final result, rather than applying the
    +        fold to each element sequentially in some defined ordering. For functions
    +        that are not commutative, the result may differ from that of a fold
    +        applied to a non-distributed collection.
    +
             >>> from operator import add
             >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
             15
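    The commutativity caveat can be seen without Spark at all; a plain-Python illustration, where the two-partition split is just an assumption for the example:

        from functools import reduce

        def sub(a, b):
            return a - b               # not commutative (nor associative)

        data = [1, 2, 3, 4, 5]
        partitions = [[1, 2], [3, 4, 5]]                          # one possible partitioning

        sequential = reduce(sub, data, 0)                         # -15
        partials = [reduce(sub, part, 0) for part in partitions]  # [-3, -12]
        distributed = reduce(sub, partials, 0)                    # 15

        # sequential != distributed, so fold() needs a commutative (and
        # associative) op to give a deterministic result.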
    @@ -952,7 +985,7 @@ def sum(self):
             >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
             6.0
             """
    -        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
    +        return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
     
         def count(self):
             """
    @@ -2190,7 +2223,7 @@ def sumApprox(self, timeout, confidence=0.95):
     
             >>> rdd = sc.parallelize(range(1000), 10)
             >>> r = sum(range(1000))
    -        >>> (rdd.sumApprox(1000) - r) / r < 0.05
    +        >>> abs(rdd.sumApprox(1000) - r) / r < 0.05
             True
             """
             jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
    @@ -2207,7 +2240,7 @@ def meanApprox(self, timeout, confidence=0.95):
     
             >>> rdd = sc.parallelize(range(1000), 10)
             >>> r = sum(range(1000)) / 1000.0
    -        >>> (rdd.meanApprox(1000) - r) / r < 0.05
    +        >>> abs(rdd.meanApprox(1000) - r) / r < 0.05
             True
             """
             jrdd = self.map(float)._to_java_object_rdd()
    @@ -2260,7 +2293,7 @@ def toLocalIterator(self):
     def _prepare_for_python_RDD(sc, command, obj=None):
         # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
    -    pickled_command = ser.dumps((command, sys.version_info[:2]))
    +    pickled_command = ser.dumps(command)
         if len(pickled_command) > (1 << 20):  # 1M
             # The broadcast will have same life cycle as created PythonRDD
             broadcast = sc.broadcast(pickled_command)
    @@ -2344,7 +2377,7 @@ def _jrdd(self):
             python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
                                                  bytearray(pickled_cmd),
                                                  env, includes, self.preservesPartitioning,
    -                                             self.ctx.pythonExec,
    +                                             self.ctx.pythonExec, self.ctx.pythonVer,
                                                  bvars, self.ctx._javaAccumulator)
             self._jrdd_val = python_rdd.asJavaRDD()
     
    diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
    index d8cdcda3a3783..411b4dbf481f1 100644
    --- a/python/pyspark/serializers.py
    +++ b/python/pyspark/serializers.py
    @@ -44,8 +44,8 @@
     
     >>> rdd.glom().collect()
     [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
    ->>> rdd._jrdd.count()
    -8L
    +>>> int(rdd._jrdd.count())
    +8
     >>> sc.stop()
     """
     
    @@ -272,7 +272,7 @@ def dump_stream(self, iterator, stream):
                 if size < best:
                     batch *= 2
                 elif size > best * 10 and batch > 1:
    -                batch /= 2
    +                batch //= 2
     
         def __repr__(self):
             return "AutoBatchedSerializer(%s)" % self.serializer
    @@ -556,4 +556,6 @@ def write_with_length(obj, stream):
     
     if __name__ == '__main__':
         import doctest
    -    doctest.testmod()
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
    diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
    index 1d0b16cade8bb..8fb71bac64a5e 100644
    --- a/python/pyspark/shuffle.py
    +++ b/python/pyspark/shuffle.py
    @@ -362,7 +362,7 @@ def _spill(self):
     
             self.spills += 1
             gc.collect()  # release the memory as much as possible
    -        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
    +        MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
     
         def items(self):
             """ Return all merged items as iterator """
    @@ -486,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False):
             goes above the limit.
             """
             global MemoryBytesSpilled, DiskBytesSpilled
    -        batch, limit = 100, self.memory_limit
    +        batch, limit = 100, self._next_limit()
             chunks, current_chunk = [], []
             iterator = iter(iterator)
             while True:
    @@ -512,10 +512,7 @@ def load(f):
                         f.close()
                     chunks.append(load(open(path, 'rb')))
                     current_chunk = []
    -                gc.collect()
    -                batch //= 2
    -                limit = self._next_limit()
    -                MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
    +                MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
                     DiskBytesSpilled += os.path.getsize(path)
                     os.unlink(path)  # data will be deleted after close
     
    @@ -630,7 +627,7 @@ def _spill(self):
             self.values = []
             gc.collect()
             DiskBytesSpilled += self._file.tell() - pos
    -        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
    +        MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
     
     
     class ExternalListOfList(ExternalList):
    @@ -794,7 +791,7 @@ def _spill(self):
     
             self.spills += 1
             gc.collect()  # release the memory as much as possible
    -        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
    +        MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
     
         def _merged_items(self, index):
             size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
    @@ -841,4 +838,6 @@ def load_partition(j):
     
     if __name__ == "__main__":
         import doctest
    -    doctest.testmod()
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
    diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
    index 7192c89b3dc7f..ad9c891ba1c04 100644
    --- a/python/pyspark/sql/__init__.py
    +++ b/python/pyspark/sql/__init__.py
    @@ -18,47 +18,58 @@
     """
     Important classes of Spark SQL and DataFrames:
     
    -    - L{SQLContext}
    +    - :class:`pyspark.sql.SQLContext`
           Main entry point for :class:`DataFrame` and SQL functionality.
    -    - L{DataFrame}
    +    - :class:`pyspark.sql.DataFrame`
           A distributed collection of data grouped into named columns.
    -    - L{Column}
    +    - :class:`pyspark.sql.Column`
           A column expression in a :class:`DataFrame`.
    -    - L{Row}
    +    - :class:`pyspark.sql.Row`
           A row of data in a :class:`DataFrame`.
    -    - L{HiveContext}
    +    - :class:`pyspark.sql.HiveContext`
           Main entry point for accessing data stored in Apache Hive.
    -    - L{GroupedData}
    +    - :class:`pyspark.sql.GroupedData`
           Aggregation methods, returned by :func:`DataFrame.groupBy`.
    -    - L{DataFrameNaFunctions}
    +    - :class:`pyspark.sql.DataFrameNaFunctions`
           Methods for handling missing data (null values).
    -    - L{DataFrameStatFunctions}
    +    - :class:`pyspark.sql.DataFrameStatFunctions`
           Methods for statistics functionality.
    -    - L{functions}
    +    - :class:`pyspark.sql.functions`
           List of built-in functions available for :class:`DataFrame`.
    -    - L{types}
    +    - :class:`pyspark.sql.types`
           List of data types available.
    +    - :class:`pyspark.sql.Window`
    +      For working with window functions.
     """
     from __future__ import absolute_import
     
    -# fix the module name conflict for Python 3+
    -import sys
    -from . import _types as types
    -modname = __name__ + '.types'
    -types.__name__ = modname
    -# update the __module__ for all objects, make them picklable
    -for v in types.__dict__.values():
    -    if hasattr(v, "__module__") and v.__module__.endswith('._types'):
    -        v.__module__ = modname
    -sys.modules[modname] = types
    -del modname, sys
    +
    +def since(version):
    +    """
    +    A decorator that annotates a function to append the version of Spark
    +    in which the function was added.
    +    """
    +    import re
    +    indent_p = re.compile(r'\n( +)')
    +
    +    def deco(f):
    +        indents = indent_p.findall(f.__doc__)
    +        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
    +        f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
    +        return f
    +    return deco
    +
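    What the decorator does to a docstring, sketched on a hypothetical function and using the `since` definition just above:

        @since(1.4)
        def demo():
            """Do nothing.

            Longer description.
            """

        print(demo.__doc__)
        # Do nothing.
        #
        #     Longer description.
        #
        #     .. versionadded:: 1.4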
     
     from pyspark.sql.types import Row
     from pyspark.sql.context import SQLContext, HiveContext
    -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
    -from pyspark.sql.dataframe import DataFrameStatFunctions
    +from pyspark.sql.column import Column
    +from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
    +from pyspark.sql.group import GroupedData
    +from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
    +from pyspark.sql.window import Window, WindowSpec
    +
     
     __all__ = [
         'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
    -    'DataFrameNaFunctions', 'DataFrameStatFunctions'
    +    'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
    +    'DataFrameReader', 'DataFrameWriter'
     ]
    diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
    new file mode 100644
    index 0000000000000..0a85da7443d3d
    --- /dev/null
    +++ b/python/pyspark/sql/column.py
    @@ -0,0 +1,430 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +import sys
    +
    +if sys.version >= '3':
    +    basestring = str
    +    long = int
    +
    +from pyspark.context import SparkContext
    +from pyspark.rdd import ignore_unicode_prefix
    +from pyspark.sql import since
    +from pyspark.sql.types import *
    +
    +__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions",
    +           "DataFrameStatFunctions"]
    +
    +
    +def _create_column_from_literal(literal):
    +    sc = SparkContext._active_spark_context
    +    return sc._jvm.functions.lit(literal)
    +
    +
    +def _create_column_from_name(name):
    +    sc = SparkContext._active_spark_context
    +    return sc._jvm.functions.col(name)
    +
    +
    +def _to_java_column(col):
    +    if isinstance(col, Column):
    +        jcol = col._jc
    +    else:
    +        jcol = _create_column_from_name(col)
    +    return jcol
    +
    +
    +def _to_seq(sc, cols, converter=None):
    +    """
    +    Convert a list of Column (or names) into a JVM Seq of Column.
    +
    +    An optional `converter` could be used to convert items in `cols`
    +    into JVM Column objects.
    +    """
    +    if converter:
    +        cols = [converter(c) for c in cols]
    +    return sc._jvm.PythonUtils.toSeq(cols)
    +
    +
    +def _unary_op(name, doc="unary operator"):
    +    """ Create a method for given unary operator """
    +    def _(self):
    +        jc = getattr(self._jc, name)()
    +        return Column(jc)
    +    _.__doc__ = doc
    +    return _
    +
    +
    +def _func_op(name, doc=''):
    +    def _(self):
    +        sc = SparkContext._active_spark_context
    +        jc = getattr(sc._jvm.functions, name)(self._jc)
    +        return Column(jc)
    +    _.__doc__ = doc
    +    return _
    +
    +
    +def _bin_op(name, doc="binary operator"):
    +    """ Create a method for given binary operator
    +    """
    +    def _(self, other):
    +        jc = other._jc if isinstance(other, Column) else other
    +        njc = getattr(self._jc, name)(jc)
    +        return Column(njc)
    +    _.__doc__ = doc
    +    return _
    +
    +
    +def _reverse_op(name, doc="binary operator"):
    +    """ Create a method for binary operator (this object is on right side)
    +    """
    +    def _(self, other):
    +        jother = _create_column_from_literal(other)
    +        jc = getattr(jother, name)(self._jc)
    +        return Column(jc)
    +    _.__doc__ = doc
    +    return _
    +
    +
    +class Column(object):
    +
    +    """
    +    A column in a DataFrame.
    +
    +    :class:`Column` instances can be created by::
    +
    +        # 1. Select a column out of a DataFrame
    +
    +        df.colName
    +        df["colName"]
    +
    +        # 2. Create from an expression
    +        df.colName + 1
    +        1 / df.colName
    +
    +    .. note:: Experimental
    +
    +    .. versionadded:: 1.3
    +    """
    +
    +    def __init__(self, jc):
    +        self._jc = jc
    +
    +    # arithmetic operators
    +    __neg__ = _func_op("negate")
    +    __add__ = _bin_op("plus")
    +    __sub__ = _bin_op("minus")
    +    __mul__ = _bin_op("multiply")
    +    __div__ = _bin_op("divide")
    +    __truediv__ = _bin_op("divide")
    +    __mod__ = _bin_op("mod")
    +    __radd__ = _bin_op("plus")
    +    __rsub__ = _reverse_op("minus")
    +    __rmul__ = _bin_op("multiply")
    +    __rdiv__ = _reverse_op("divide")
    +    __rtruediv__ = _reverse_op("divide")
    +    __rmod__ = _reverse_op("mod")
    +
    +    # comparison operators
    +    __eq__ = _bin_op("equalTo")
    +    __ne__ = _bin_op("notEqual")
    +    __lt__ = _bin_op("lt")
    +    __le__ = _bin_op("leq")
    +    __ge__ = _bin_op("geq")
    +    __gt__ = _bin_op("gt")
    +
    +    # `and`, `or`, `not` cannot be overloaded in Python,
    +    # so use bitwise operators as boolean operators
    +    __and__ = _bin_op('and')
    +    __or__ = _bin_op('or')
    +    __invert__ = _func_op('not')
    +    __rand__ = _bin_op("and")
    +    __ror__ = _bin_op("or")
    +
    +    # container operators
    +    __contains__ = _bin_op("contains")
    +    __getitem__ = _bin_op("apply")
    +
    +    # bitwise operators
    +    bitwiseOR = _bin_op("bitwiseOR")
    +    bitwiseAND = _bin_op("bitwiseAND")
    +    bitwiseXOR = _bin_op("bitwiseXOR")
    +
    +    @since(1.3)
    +    def getItem(self, key):
    +        """
    +        An expression that gets an item at position ``key`` out of a list,
    +        or gets an item by key out of a dict.
    +
    +        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
    +        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
    +        +----+------+
    +        |l[0]|d[key]|
    +        +----+------+
    +        |   1| value|
    +        +----+------+
    +        >>> df.select(df.l[0], df.d["key"]).show()
    +        +----+------+
    +        |l[0]|d[key]|
    +        +----+------+
    +        |   1| value|
    +        +----+------+
    +        """
    +        return self[key]
    +
    +    @since(1.3)
    +    def getField(self, name):
    +        """
    +        An expression that gets a field by name in a StructType.
    +
    +        >>> from pyspark.sql import Row
    +        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
    +        >>> df.select(df.r.getField("b")).show()
    +        +----+
    +        |r[b]|
    +        +----+
    +        |   b|
    +        +----+
    +        >>> df.select(df.r.a).show()
    +        +----+
    +        |r[a]|
    +        +----+
    +        |   1|
    +        +----+
    +        """
    +        return self[name]
    +
    +    def __getattr__(self, item):
    +        if item.startswith("__"):
    +            raise AttributeError(item)
    +        return self.getField(item)
    +
    +    # string methods
    +    rlike = _bin_op("rlike")
    +    like = _bin_op("like")
    +    startswith = _bin_op("startsWith")
    +    endswith = _bin_op("endsWith")
    +
    +    @ignore_unicode_prefix
    +    @since(1.3)
    +    def substr(self, startPos, length):
    +        """
    +        Return a :class:`Column` which is a substring of the column.
    +
    +        :param startPos: start position (int or Column)
    +        :param length:  length of the substring (int or Column)
    +
    +        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
    +        [Row(col=u'Ali'), Row(col=u'Bob')]
    +        """
    +        if type(startPos) != type(length):
    +            raise TypeError("Can not mix the type")
    +        if isinstance(startPos, (int, long)):
    +            jc = self._jc.substr(startPos, length)
    +        elif isinstance(startPos, Column):
    +            jc = self._jc.substr(startPos._jc, length._jc)
    +        else:
    +            raise TypeError("Unexpected type: %s" % type(startPos))
    +        return Column(jc)
    +
    +    __getslice__ = substr
    +
    +    @ignore_unicode_prefix
    +    @since(1.3)
    +    def inSet(self, *cols):
    +        """
    +        A boolean expression that is evaluated to true if the value of this
    +        expression is contained by the evaluated values of the arguments.
    +
    +        >>> df[df.name.inSet("Bob", "Mike")].collect()
    +        [Row(age=5, name=u'Bob')]
    +        >>> df[df.age.inSet([1, 2, 3])].collect()
    +        [Row(age=2, name=u'Alice')]
    +        """
    +        if len(cols) == 1 and isinstance(cols[0], (list, set)):
    +            cols = cols[0]
    +        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
    +        sc = SparkContext._active_spark_context
    +        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
    +        return Column(jc)
    +
    +    # order
    +    asc = _unary_op("asc", "Returns a sort expression based on the"
    +                           " ascending order of the given column name.")
    +    desc = _unary_op("desc", "Returns a sort expression based on the"
    +                             " descending order of the given column name.")
    +
    +    isNull = _unary_op("isNull", "True if the current expression is null.")
    +    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
    +
    +    @since(1.3)
    +    def alias(self, *alias):
    +        """
    +        Returns this column aliased with a new name or names (in the case of expressions that
    +        return more than one column, such as explode).
    +
    +        >>> df.select(df.age.alias("age2")).collect()
    +        [Row(age2=2), Row(age2=5)]
    +        """
    +
    +        if len(alias) == 1:
    +            return Column(getattr(self._jc, "as")(alias[0]))
    +        else:
    +            sc = SparkContext._active_spark_context
    +            return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
    +
    +    @ignore_unicode_prefix
    +    @since(1.3)
    +    def cast(self, dataType):
    +        """ Convert the column into type ``dataType``.
    +
    +        >>> df.select(df.age.cast("string").alias('ages')).collect()
    +        [Row(ages=u'2'), Row(ages=u'5')]
    +        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
    +        [Row(ages=u'2'), Row(ages=u'5')]
    +        """
    +        if isinstance(dataType, basestring):
    +            jc = self._jc.cast(dataType)
    +        elif isinstance(dataType, DataType):
    +            sc = SparkContext._active_spark_context
    +            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
    +            jdt = ssql_ctx.parseDataType(dataType.json())
    +            jc = self._jc.cast(jdt)
    +        else:
    +            raise TypeError("unexpected type: %s" % type(dataType))
    +        return Column(jc)
    +
    +    astype = cast
    +
    +    @since(1.3)
    +    def between(self, lowerBound, upperBound):
    +        """
    +        A boolean expression that is evaluated to true if the value of this
    +        expression is between the given columns.
    +
    +        >>> df.select(df.name, df.age.between(2, 4)).show()
    +        +-----+--------------------------+
    +        | name|((age >= 2) && (age <= 4))|
    +        +-----+--------------------------+
    +        |Alice|                      true|
    +        |  Bob|                     false|
    +        +-----+--------------------------+
    +        """
    +        return (self >= lowerBound) & (self <= upperBound)
    +
    +    @since(1.4)
    +    def when(self, condition, value):
    +        """
    +        Evaluates a list of conditions and returns one of multiple possible result expressions.
    +        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
    +
    +        See :func:`pyspark.sql.functions.when` for example usage.
    +
    +        :param condition: a boolean :class:`Column` expression.
    +        :param value: a literal value, or a :class:`Column` expression.
    +
    +        >>> from pyspark.sql import functions as F
    +        >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
    +        +-----+--------------------------------------------------------+
    +        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
    +        +-----+--------------------------------------------------------+
    +        |Alice|                                                      -1|
    +        |  Bob|                                                       1|
    +        +-----+--------------------------------------------------------+
    +        """
    +        if not isinstance(condition, Column):
    +            raise TypeError("condition should be a Column")
    +        v = value._jc if isinstance(value, Column) else value
    +        jc = self._jc.when(condition._jc, v)
    +        return Column(jc)
    +
    +    @since(1.4)
    +    def otherwise(self, value):
    +        """
    +        Evaluates a list of conditions and returns one of multiple possible result expressions.
    +        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
    +
    +        See :func:`pyspark.sql.functions.when` for example usage.
    +
    +        :param value: a literal value, or a :class:`Column` expression.
    +
    +        >>> from pyspark.sql import functions as F
    +        >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
    +        +-----+---------------------------------+
    +        | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
    +        +-----+---------------------------------+
    +        |Alice|                                0|
    +        |  Bob|                                1|
    +        +-----+---------------------------------+
    +        """
    +        v = value._jc if isinstance(value, Column) else value
    +        jc = self._jc.otherwise(v)
    +        return Column(jc)
    +
    +    @since(1.4)
    +    def over(self, window):
    +        """
    +        Define a windowing column.
    +
    +        :param window: a :class:`WindowSpec`
    +        :return: a Column
    +
    +        >>> from pyspark.sql import Window
    +        >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1)
    +        >>> from pyspark.sql.functions import rank, min
    +        >>> # df.select(rank().over(window), min('age').over(window))
    +
    +        .. note:: Window functions are only supported with HiveContext in 1.4
    +        """
    +        from pyspark.sql.window import WindowSpec
    +        if not isinstance(window, WindowSpec):
    +            raise TypeError("window should be WindowSpec")
    +        jc = self._jc.over(window._jspec)
    +        return Column(jc)
    +
    +    def __nonzero__(self):
    +        raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
    +                         "'~' for 'not' when building DataFrame boolean expressions.")
    +    __bool__ = __nonzero__
    +
    +    def __repr__(self):
    +        return 'Column<%s>' % self._jc.toString().encode('utf8')
    +
    +
    +def _test():
    +    import doctest
    +    from pyspark.context import SparkContext
    +    from pyspark.sql import SQLContext
    +    import pyspark.sql.column
    +    globs = pyspark.sql.column.__dict__.copy()
    +    sc = SparkContext('local[4]', 'PythonTest')
    +    globs['sc'] = sc
    +    globs['sqlContext'] = SQLContext(sc)
    +    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
    +        .toDF(StructType([StructField('age', IntegerType()),
    +                          StructField('name', StringType())]))
    +
    +    (failure_count, test_count) = doctest.testmod(
    +        pyspark.sql.column, globs=globs,
    +        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    +    globs['sc'].stop()
    +    if failure_count:
    +        exit(-1)
    +
    +
    +if __name__ == "__main__":
    +    _test()
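    A condensed usage sketch of the Column API defined above. It assumes the same two-row `df` with `age` and `name` columns that the module's doctests build; each line just constructs a lazy expression:

        from pyspark.sql import functions as F

        df[(df.age > 3) & ~df.name.startswith("A")]             # boolean expressions
        df.select(df.age.cast("string").alias("age_str"))       # cast + alias
        df.select(df.name.substr(1, 3).alias("prefix"))         # substring
        df.select(df.name, F.when(df.age > 3, 1).otherwise(0))  # conditional column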
    diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
    index f6f107ca32d2f..c93a15badae29 100644
    --- a/python/pyspark/sql/context.py
    +++ b/python/pyspark/sql/context.py
    @@ -28,9 +28,12 @@
     
     from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
     from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
    +from pyspark.sql import since
     from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
    -    _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
    +    _infer_schema, _has_nulltype, _merge_type, _create_converter
     from pyspark.sql.dataframe import DataFrame
    +from pyspark.sql.readwriter import DataFrameReader
    +from pyspark.sql.utils import install_exception_handler
     
     try:
         import pandas
    @@ -84,7 +87,8 @@ def __init__(self, sparkContext, sqlContext=None):
             >>> df.registerTempTable("allTypes")
             >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
             ...            'from allTypes where b and i > 0').collect()
    -        [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
    +        [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
    +            time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
             >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
             [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
             """
    @@ -93,6 +97,7 @@ def __init__(self, sparkContext, sqlContext=None):
             self._jvm = self._sc._jvm
             self._scala_SQLContext = sqlContext
             _monkey_patch_RDD(self)
    +        install_exception_handler()
     
         @property
         def _ssql_ctx(self):
    @@ -105,11 +110,13 @@ def _ssql_ctx(self):
                 self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
             return self._scala_SQLContext
     
    +    @since(1.3)
         def setConf(self, key, value):
             """Sets the given Spark SQL configuration property.
             """
             self._ssql_ctx.setConf(key, value)
     
    +    @since(1.3)
         def getConf(self, key, defaultValue):
             """Returns the value of Spark SQL configuration property for the given key.
     
    @@ -118,11 +125,47 @@ def getConf(self, key, defaultValue):
             return self._ssql_ctx.getConf(key, defaultValue)
     
         @property
    +    @since("1.3.1")
         def udf(self):
    -        """Returns a :class:`UDFRegistration` for UDF registration."""
    +        """Returns a :class:`UDFRegistration` for UDF registration.
    +
    +        :return: :class:`UDFRegistration`
    +        """
             return UDFRegistration(self)
     
    +    @since(1.4)
    +    def range(self, start, end=None, step=1, numPartitions=None):
    +        """
    +        Create a :class:`DataFrame` with a single LongType column named `id`,
    +        containing elements in a range from `start` to `end` (exclusive) with
    +        step value `step`.
    +
    +        :param start: the start value
    +        :param end: the end value (exclusive)
    +        :param step: the incremental step (default: 1)
    +        :param numPartitions: the number of partitions of the DataFrame
    +        :return: :class:`DataFrame`
    +
    +        >>> sqlContext.range(1, 7, 2).collect()
    +        [Row(id=1), Row(id=3), Row(id=5)]
    +
    +        If only one argument is specified, it will be used as the end value.
    +
    +        >>> sqlContext.range(3).collect()
    +        [Row(id=0), Row(id=1), Row(id=2)]
    +        """
    +        if numPartitions is None:
    +            numPartitions = self._sc.defaultParallelism
    +
    +        if end is None:
    +            jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
    +        else:
    +            jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
    +
    +        return DataFrame(jdf, self)
    +
         @ignore_unicode_prefix
    +    @since(1.2)
         def registerFunction(self, name, f, returnType=StringType()):
             """Registers a lambda function as a UDF so it can be used in SQL statements.
     
    @@ -136,17 +179,17 @@ def registerFunction(self, name, f, returnType=StringType()):
     
             >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
             >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
    -        [Row(c0=u'4')]
    +        [Row(_c0=u'4')]
     
             >>> from pyspark.sql.types import IntegerType
             >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
             >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
    -        [Row(c0=4)]
    +        [Row(_c0=4)]
     
             >>> from pyspark.sql.types import IntegerType
             >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
             >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
    -        [Row(c0=4)]
    +        [Row(_c0=4)]
             """
             func = lambda _, it: map(lambda x: f(*x), it)
             ser = AutoBatchedSerializer(PickleSerializer())
    @@ -157,18 +200,49 @@ def registerFunction(self, name, f, returnType=StringType()):
                                                 env,
                                                 includes,
                                                 self._sc.pythonExec,
    +                                            self._sc.pythonVer,
                                                 bvars,
                                                 self._sc._javaAccumulator,
                                                 returnType.json())
     
    +    def _inferSchemaFromList(self, data):
    +        """
    +        Infer schema from list of Row or tuple.
    +
    +        :param data: list of Row or tuple
    +        :return: StructType
    +        """
    +        if not data:
    +            raise ValueError("can not infer schema from empty dataset")
    +        first = data[0]
    +        if type(first) is dict:
    +            warnings.warn("inferring schema from dict is deprecated,"
    +                          "please use pyspark.sql.Row instead")
    +        schema = _infer_schema(first)
    +        if _has_nulltype(schema):
    +            for r in data:
    +                schema = _merge_type(schema, _infer_schema(r))
    +                if not _has_nulltype(schema):
    +                    break
    +            else:
    +                raise ValueError("Some of types cannot be determined after inferring")
    +        return schema
    +
         def _inferSchema(self, rdd, samplingRatio=None):
    +        """
    +        Infer schema from an RDD of Row or tuple.
    +
    +        :param rdd: an RDD of Row or tuple
    +        :param samplingRatio: sampling ratio, or no sampling (default)
    +        :return: StructType
    +        """
             first = rdd.first()
             if not first:
                 raise ValueError("The first row in RDD is empty, "
                                  "can not infer schema")
             if type(first) is dict:
    -            warnings.warn("Using RDD of dict to inferSchema is deprecated,"
    -                          "please use pyspark.sql.Row instead")
    +            warnings.warn("Using RDD of dict to inferSchema is deprecated. "
    +                          "Use pyspark.sql.Row instead")
     
             if samplingRatio is None:
                 schema = _infer_schema(first)
    @@ -188,9 +262,10 @@ def _inferSchema(self, rdd, samplingRatio=None):
     
         @ignore_unicode_prefix
         def inferSchema(self, rdd, samplingRatio=None):
    -        """::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
             """
    -        warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
    +        .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
    +        """
    +        warnings.warn("inferSchema is deprecated, please use createDataFrame instead.")
     
             if isinstance(rdd, DataFrame):
                 raise TypeError("Cannot apply schema to DataFrame")
    @@ -199,7 +274,8 @@ def inferSchema(self, rdd, samplingRatio=None):
     
         @ignore_unicode_prefix
         def applySchema(self, rdd, schema):
    -        """::note: Deprecated in 1.3, use :func:`createDataFrame` instead.
    +        """
    +        .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
             """
             warnings.warn("applySchema is deprecated, please use createDataFrame instead")
     
    @@ -211,6 +287,7 @@ def applySchema(self, rdd, schema):
     
             return self.createDataFrame(rdd, schema)
     
    +    @since(1.3)
         @ignore_unicode_prefix
         def createDataFrame(self, data, schema=None, samplingRatio=None):
             """
    @@ -231,6 +308,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
                 :class:`list`, or :class:`pandas.DataFrame`.
             :param schema: a :class:`StructType` or list of column names. default None.
             :param samplingRatio: the sample ratio of rows used for inferring
    +        :return: :class:`DataFrame`
     
             >>> l = [('Alice', 1)]
             >>> sqlContext.createDataFrame(l).collect()
    @@ -266,16 +344,20 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
     
             >>> sqlContext.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
             [Row(name=u'Alice', age=1)]
    +        >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
    +        [Row(0=1, 1=2)]
             """
             if isinstance(data, DataFrame):
                 raise TypeError("data is already a DataFrame")
     
             if has_pandas and isinstance(data, pandas.DataFrame):
                 if schema is None:
    -                schema = list(data.columns)
    +                schema = [str(x) for x in data.columns]
                 data = [r.tolist() for r in data.to_records(index=False)]
     
             if not isinstance(data, RDD):
    +            if not isinstance(data, list):
    +                data = list(data)
                 try:
                     # data could be list, tuple, generator ...
                     rdd = self._sc.parallelize(data)
    @@ -284,37 +366,35 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
             else:
                 rdd = data
     
    -        if schema is None:
    -            schema = self._inferSchema(rdd, samplingRatio)
    +        if schema is None or isinstance(schema, (list, tuple)):
    +            if isinstance(data, RDD):
    +                struct = self._inferSchema(rdd, samplingRatio)
    +            else:
    +                struct = self._inferSchemaFromList(data)
    +            if isinstance(schema, (list, tuple)):
    +                for i, name in enumerate(schema):
    +                    struct.fields[i].name = name
    +            schema = struct
                 converter = _create_converter(schema)
                 rdd = rdd.map(converter)
     
    -        if isinstance(schema, (list, tuple)):
    -            first = rdd.first()
    -            if not isinstance(first, (list, tuple)):
    -                raise TypeError("each row in `rdd` should be list or tuple, "
    -                                "but got %r" % type(first))
    -            row_cls = Row(*schema)
    -            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
    -
    -        # take the first few rows to verify schema
    -        rows = rdd.take(10)
    -        # Row() cannot been deserialized by Pyrolite
    -        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
    -            rdd = rdd.map(tuple)
    +        elif isinstance(schema, StructType):
    +            # take the first few rows to verify schema
                 rows = rdd.take(10)
    +            for row in rows:
    +                _verify_type(row, schema)
     
    -        for row in rows:
    -            _verify_type(row, schema)
    +        else:
    +            raise TypeError("schema should be StructType or list or None")
     
             # convert python objects to sql data
    -        converter = _python_to_sql_converter(schema)
    -        rdd = rdd.map(converter)
    +        rdd = rdd.map(schema.toInternal)
     
             jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
             df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
             return DataFrame(df, self)
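A minimal sketch of the list-plus-column-names path added above: the schema is inferred from the data and the given names are then copied onto the inferred struct fields. The local-mode setup below is illustrative only, not part of the patch.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local[2]", "createDataFrame-sketch")  # illustrative local setup
    sqlContext = SQLContext(sc)

    # a plain list of tuples plus a list of names, no Row or StructType boilerplate
    df = sqlContext.createDataFrame([("Alice", 1), ("Bob", 5)], ["name", "age"])
    df.printSchema()
    print(df.collect())  # roughly [Row(name=u'Alice', age=1), Row(name=u'Bob', age=5)]

    sc.stop()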
     
    +    @since(1.3)
         def registerDataFrameAsTable(self, df, tableName):
             """Registers the given :class:`DataFrame` as a temporary table in the catalog.
     
    @@ -330,14 +410,12 @@ def registerDataFrameAsTable(self, df, tableName):
         def parquetFile(self, *paths):
             """Loads a Parquet file, returning the result as a :class:`DataFrame`.
     
    -        >>> import tempfile, shutil
    -        >>> parquetFile = tempfile.mkdtemp()
    -        >>> shutil.rmtree(parquetFile)
    -        >>> df.saveAsParquetFile(parquetFile)
    -        >>> df2 = sqlContext.parquetFile(parquetFile)
    -        >>> sorted(df.collect()) == sorted(df2.collect())
    -        True
    +        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead.
    +
    +        >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes
    +        [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
             """
    +        warnings.warn("parquetFile is deprecated. Use read.parquet() instead.")
             gateway = self._sc._gateway
             jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
             for i in range(0, len(paths)):
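A hedged sketch of the migration this hunk encourages, assuming a live sqlContext and the test data path used in the doctest above:

    # deprecated: sqlContext.parquetFile('python/test_support/sql/parquet_partitioned')
    # preferred:
    df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
    print(df.dtypes)  # [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]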
    @@ -348,35 +426,12 @@ def parquetFile(self, *paths):
         def jsonFile(self, path, schema=None, samplingRatio=1.0):
             """Loads a text file storing one JSON object per line as a :class:`DataFrame`.
     
    -        If the schema is provided, applies the given schema to this JSON dataset.
    -        Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.
    +        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead.
     
    -        >>> import tempfile, shutil
    -        >>> jsonFile = tempfile.mkdtemp()
    -        >>> shutil.rmtree(jsonFile)
    -        >>> with open(jsonFile, 'w') as f:
    -        ...     f.writelines(jsonStrings)
    -        >>> df1 = sqlContext.jsonFile(jsonFile)
    -        >>> df1.printSchema()
    -        root
    -         |-- field1: long (nullable = true)
    -         |-- field2: string (nullable = true)
    -         |-- field3: struct (nullable = true)
    -         |    |-- field4: long (nullable = true)
    -
    -        >>> from pyspark.sql.types import *
    -        >>> schema = StructType([
    -        ...     StructField("field2", StringType()),
    -        ...     StructField("field3",
    -        ...         StructType([StructField("field5", ArrayType(IntegerType()))]))])
    -        >>> df2 = sqlContext.jsonFile(jsonFile, schema)
    -        >>> df2.printSchema()
    -        root
    -         |-- field2: string (nullable = true)
    -         |-- field3: struct (nullable = true)
    -         |    |-- field5: array (nullable = true)
    -         |    |    |-- element: integer (containsNull = true)
    +        >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes
    +        [('age', 'bigint'), ('name', 'string')]
             """
    +        warnings.warn("jsonFile is deprecated. Use read.json() instead.")
             if schema is None:
                 df = self._ssql_ctx.jsonFile(path, samplingRatio)
             else:
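The same migration applies to JSON. A sketch assuming sqlContext and the people.json test file referenced above; the explicit-schema form mirrors the old jsonFile(path, schema) call and is assumed from the 1.4 reader API:

    from pyspark.sql.types import StructType, StructField, StringType

    people = sqlContext.read.json('python/test_support/sql/people.json')
    print(people.dtypes)  # [('age', 'bigint'), ('name', 'string')]

    # an explicit schema can still be supplied, skipping inference
    schema = StructType([StructField("name", StringType(), True)])
    names_only = sqlContext.read.json('python/test_support/sql/people.json', schema)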
    @@ -385,6 +440,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
             return DataFrame(df, self)
     
         @ignore_unicode_prefix
    +    @since(1.0)
         def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
             """Loads an RDD storing one JSON object per string as a :class:`DataFrame`.
     
    @@ -430,28 +486,13 @@ def func(iterator):
         def load(self, path=None, source=None, schema=None, **options):
             """Returns the dataset in a data source as a :class:`DataFrame`.
     
    -        The data source is specified by the ``source`` and a set of ``options``.
    -        If ``source`` is not specified, the default data source configured by
    -        ``spark.sql.sources.default`` will be used.
    -
    -        Optionally, a schema can be provided as the schema of the returned DataFrame.
    +        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead.
             """
    -        if path is not None:
    -            options["path"] = path
    -        if source is None:
    -            source = self.getConf("spark.sql.sources.default",
    -                                  "org.apache.spark.sql.parquet")
    -        if schema is None:
    -            df = self._ssql_ctx.load(source, options)
    -        else:
    -            if not isinstance(schema, StructType):
    -                raise TypeError("schema should be StructType")
    -            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
    -            df = self._ssql_ctx.load(source, scala_datatype, options)
    -        return DataFrame(df, self)
    +        warnings.warn("load is deprecated. Use read.load() instead.")
    +        return self.read.load(path, source, schema, **options)
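A sketch of the equivalent reader calls; the format= keyword and the format() builder are assumed from the 1.4 DataFrameReader, with sqlContext as in the doctests:

    # old, now deprecated:
    # df = sqlContext.load('python/test_support/sql/parquet_partitioned', source='parquet')

    # new, either as a single call or through the builder-style reader:
    df1 = sqlContext.read.load('python/test_support/sql/parquet_partitioned', format='parquet')
    df2 = sqlContext.read.format('parquet').load('python/test_support/sql/parquet_partitioned')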
     
    -    def createExternalTable(self, tableName, path=None, source=None,
    -                            schema=None, **options):
    +    @since(1.3)
    +    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
             """Creates an external table based on the dataset in a data source.
     
             It returns the DataFrame associated with the external table.
    @@ -462,6 +503,8 @@ def createExternalTable(self, tableName, path=None, source=None,
     
             Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
             created external table.
    +
    +        :return: :class:`DataFrame`
             """
             if path is not None:
                 options["path"] = path
    @@ -479,9 +522,12 @@ def createExternalTable(self, tableName, path=None, source=None,
             return DataFrame(df, self)
     
         @ignore_unicode_prefix
    +    @since(1.0)
         def sql(self, sqlQuery):
             """Returns a :class:`DataFrame` representing the result of the given query.
     
    +        :return: :class:`DataFrame`
    +
             >>> sqlContext.registerDataFrameAsTable(df, "table1")
             >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
             >>> df2.collect()
    @@ -489,9 +535,12 @@ def sql(self, sqlQuery):
             """
             return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
     
    +    @since(1.0)
         def table(self, tableName):
             """Returns the specified table as a :class:`DataFrame`.
     
    +        :return: :class:`DataFrame`
    +
             >>> sqlContext.registerDataFrameAsTable(df, "table1")
             >>> df2 = sqlContext.table("table1")
             >>> sorted(df.collect()) == sorted(df2.collect())
    @@ -500,6 +549,7 @@ def table(self, tableName):
             return DataFrame(self._ssql_ctx.table(tableName), self)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def tables(self, dbName=None):
             """Returns a :class:`DataFrame` containing names of tables in the given database.
     
    @@ -508,6 +558,9 @@ def tables(self, dbName=None):
             The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
             (a column with :class:`BooleanType` indicating if a table is a temporary one or not).
     
    +        :param dbName: string, name of the database to use.
    +        :return: :class:`DataFrame`
    +
             >>> sqlContext.registerDataFrameAsTable(df, "table1")
             >>> df2 = sqlContext.tables()
             >>> df2.filter("tableName = 'table1'").first()
    @@ -518,10 +571,12 @@ def tables(self, dbName=None):
             else:
                 return DataFrame(self._ssql_ctx.tables(dbName), self)
     
    +    @since(1.3)
         def tableNames(self, dbName=None):
             """Returns a list of names of tables in the database ``dbName``.
     
    -        If ``dbName`` is not specified, the current database will be used.
     +        :param dbName: string, name of the database to use. Defaults to the current database.
     +        :return: list of table names, as strings
     
             >>> sqlContext.registerDataFrameAsTable(df, "table1")
             >>> "table1" in sqlContext.tableNames()
    @@ -534,18 +589,32 @@ def tableNames(self, dbName=None):
             else:
                 return [name for name in self._ssql_ctx.tableNames(dbName)]
     
    +    @since(1.0)
         def cacheTable(self, tableName):
             """Caches the specified table in-memory."""
             self._ssql_ctx.cacheTable(tableName)
     
    +    @since(1.0)
         def uncacheTable(self, tableName):
             """Removes the specified table from the in-memory cache."""
             self._ssql_ctx.uncacheTable(tableName)
     
    +    @since(1.3)
         def clearCache(self):
             """Removes all cached tables from the in-memory cache. """
             self._ssql_ctx.clearCache()
     
    +    @property
    +    @since(1.4)
    +    def read(self):
    +        """
    +        Returns a :class:`DataFrameReader` that can be used to read data
    +        in as a :class:`DataFrame`.
    +
    +        :return: :class:`DataFrameReader`
    +        """
    +        return DataFrameReader(self)
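Note that read is a property, not a method, and per this patch every access constructs a fresh DataFrameReader. A small sketch:

    reader = sqlContext.read                 # property access, no parentheses
    assert reader is not sqlContext.read     # a new DataFrameReader is built on each access
    df = reader.parquet('python/test_support/sql/parquet_partitioned')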
    +
     
     class HiveContext(SQLContext):
         """A variant of Spark SQL that integrates with data stored in Hive.
    @@ -600,10 +669,14 @@ def register(self, name, f, returnType=StringType()):
     
     
     def _test():
    +    import os
         import doctest
         from pyspark.context import SparkContext
         from pyspark.sql import Row, SQLContext
         import pyspark.sql.context
    +
    +    os.chdir(os.environ["SPARK_HOME"])
    +
         globs = pyspark.sql.context.__dict__.copy()
         sc = SparkContext('local[4]', 'PythonTest')
         globs['sc'] = sc
    diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
    index 82cb1c2fdbf94..83e02b85f06f1 100644
    --- a/python/pyspark/sql/dataframe.py
    +++ b/python/pyspark/sql/dataframe.py
    @@ -22,20 +22,21 @@
     if sys.version >= '3':
         basestring = unicode = str
         long = int
    +    from functools import reduce
     else:
         from itertools import imap as map
     
    -from pyspark.context import SparkContext
     from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
     from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
     from pyspark.storagelevel import StorageLevel
     from pyspark.traceback_utils import SCCallSiteSync
    +from pyspark.sql import since
    +from pyspark.sql.types import _parse_datatype_json_string
    +from pyspark.sql.column import Column, _to_seq, _to_java_column
    +from pyspark.sql.readwriter import DataFrameWriter
     from pyspark.sql.types import *
    -from pyspark.sql.types import _create_cls, _parse_datatype_json_string
     
    -
    -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions",
    -           "DataFrameStatFunctions"]
    +__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]
     
     
     class DataFrame(object):
    @@ -44,7 +45,7 @@ class DataFrame(object):
         A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
         and can be created using various functions in :class:`SQLContext`::
     
    -        people = sqlContext.parquetFile("...")
    +        people = sqlContext.read.parquet("...")
     
         Once created, it can be manipulated using the various domain-specific-language
         (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
    @@ -56,11 +57,15 @@ class DataFrame(object):
         A more concrete example::
     
             # To create DataFrame using SQLContext
    -        people = sqlContext.parquetFile("...")
    -        department = sqlContext.parquetFile("...")
    +        people = sqlContext.read.parquet("...")
    +        department = sqlContext.read.parquet("...")
     
              people.filter(people.age > 30).join(department, people.deptId == department.id) \
               .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
    +
    +    .. note:: Experimental
    +
    +    .. versionadded:: 1.3
         """
     
         def __init__(self, jdf, sql_ctx):
    @@ -72,35 +77,31 @@ def __init__(self, jdf, sql_ctx):
             self._lazy_rdd = None
     
         @property
    +    @since(1.3)
         def rdd(self):
             """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
             """
             if self._lazy_rdd is None:
                 jrdd = self._jdf.javaToPython()
    -            rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
    -            schema = self.schema
    -
    -            def applySchema(it):
    -                cls = _create_cls(schema)
    -                return map(cls, it)
    -
    -            self._lazy_rdd = rdd.mapPartitions(applySchema)
    -
    +            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
             return self._lazy_rdd
     
         @property
    +    @since("1.3.1")
         def na(self):
             """Returns a :class:`DataFrameNaFunctions` for handling missing values.
             """
             return DataFrameNaFunctions(self)
     
         @property
    +    @since(1.4)
         def stat(self):
             """Returns a :class:`DataFrameStatFunctions` for statistic functions.
             """
             return DataFrameStatFunctions(self)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def toJSON(self, use_unicode=True):
             """Converts a :class:`DataFrame` into a :class:`RDD` of string.
     
    @@ -115,19 +116,12 @@ def toJSON(self, use_unicode=True):
         def saveAsParquetFile(self, path):
             """Saves the contents as a Parquet file, preserving the schema.
     
    -        Files that are written out using this method can be read back in as
    -        a :class:`DataFrame` using :func:`SQLContext.parquetFile`.
    -
    -        >>> import tempfile, shutil
    -        >>> parquetFile = tempfile.mkdtemp()
    -        >>> shutil.rmtree(parquetFile)
    -        >>> df.saveAsParquetFile(parquetFile)
    -        >>> df2 = sqlContext.parquetFile(parquetFile)
    -        >>> sorted(df2.collect()) == sorted(df.collect())
    -        True
    +        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead.
             """
    +        warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.")
             self._jdf.saveAsParquetFile(path)
     
    +    @since(1.3)
         def registerTempTable(self, name):
             """Registers this RDD as a temporary table using the given name.
     
    @@ -142,81 +136,49 @@ def registerTempTable(self, name):
             self._jdf.registerTempTable(name)
     
         def registerAsTable(self, name):
    -        """DEPRECATED: use :func:`registerTempTable` instead"""
    -        warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
    +        """
    +        .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead.
    +        """
    +        warnings.warn("Use registerTempTable instead of registerAsTable.")
             self.registerTempTable(name)
     
         def insertInto(self, tableName, overwrite=False):
             """Inserts the contents of this :class:`DataFrame` into the specified table.
     
    -        Optionally overwriting any existing data.
    +        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead.
             """
    -        self._jdf.insertInto(tableName, overwrite)
    -
    -    def _java_save_mode(self, mode):
    -        """Returns the Java save mode based on the Python save mode represented by a string.
    -        """
    -        jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode
    -        jmode = jSaveMode.ErrorIfExists
    -        mode = mode.lower()
    -        if mode == "append":
    -            jmode = jSaveMode.Append
    -        elif mode == "overwrite":
    -            jmode = jSaveMode.Overwrite
    -        elif mode == "ignore":
    -            jmode = jSaveMode.Ignore
    -        elif mode == "error":
    -            pass
    -        else:
    -            raise ValueError(
    -                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
    -        return jmode
    +        warnings.warn("insertInto is deprecated. Use write.insertInto() instead.")
    +        self.write.insertInto(tableName, overwrite)
     
         def saveAsTable(self, tableName, source=None, mode="error", **options):
             """Saves the contents of this :class:`DataFrame` to a data source as a table.
     
    -        The data source is specified by the ``source`` and a set of ``options``.
    -        If ``source`` is not specified, the default data source configured by
    -        ``spark.sql.sources.default`` will be used.
    -
    -        Additionally, mode is used to specify the behavior of the saveAsTable operation when
    -        table already exists in the data source. There are four modes:
    -
    -        * `append`: Append contents of this :class:`DataFrame` to existing data.
    -        * `overwrite`: Overwrite existing data.
    -        * `error`: Throw an exception if data already exists.
    -        * `ignore`: Silently ignore this operation if data already exists.
    +        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead.
             """
    -        if source is None:
    -            source = self.sql_ctx.getConf("spark.sql.sources.default",
    -                                          "org.apache.spark.sql.parquet")
    -        jmode = self._java_save_mode(mode)
    -        self._jdf.saveAsTable(tableName, source, jmode, options)
    +        warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.")
    +        self.write.saveAsTable(tableName, source, mode, **options)
     
    +    @since(1.3)
         def save(self, path=None, source=None, mode="error", **options):
             """Saves the contents of the :class:`DataFrame` to a data source.
     
    -        The data source is specified by the ``source`` and a set of ``options``.
    -        If ``source`` is not specified, the default data source configured by
    -        ``spark.sql.sources.default`` will be used.
    +        .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead.
    +        """
    +        warnings.warn("insertInto is deprecated. Use write.save() instead.")
    +        return self.write.save(path, source, mode, **options)
     
    -        Additionally, mode is used to specify the behavior of the save operation when
    -        data already exists in the data source. There are four modes:
    +    @property
    +    @since(1.4)
    +    def write(self):
    +        """
    +        Interface for saving the content of the :class:`DataFrame` out into external storage.
     
    -        * `append`: Append contents of this :class:`DataFrame` to existing data.
    -        * `overwrite`: Overwrite existing data.
    -        * `error`: Throw an exception if data already exists.
    -        * `ignore`: Silently ignore this operation if data already exists.
    +        :return: :class:`DataFrameWriter`
             """
    -        if path is not None:
    -            options["path"] = path
    -        if source is None:
    -            source = self.sql_ctx.getConf("spark.sql.sources.default",
    -                                          "org.apache.spark.sql.parquet")
    -        jmode = self._java_save_mode(mode)
    -        self._jdf.save(source, jmode, options)
    +        return DataFrameWriter(self)
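A hedged sketch of the writer-based replacements for save/saveAsTable/insertInto, with df as in the doctests; the temporary output path is illustrative:

    import os, tempfile

    out = os.path.join(tempfile.mkdtemp(), 'people.parquet')  # throwaway location

    # deprecated: df.save(out, source='parquet', mode='overwrite')
    df.write.mode('overwrite').parquet(out)

    # likewise: df.saveAsTable('people')  ->  df.write.saveAsTable('people')
    #           df.insertInto('people')   ->  df.write.insertInto('people')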
     
         @property
    +    @since(1.3)
         def schema(self):
             """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`.
     
    @@ -224,9 +186,14 @@ def schema(self):
             StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
             """
             if self._schema is None:
    -            self._schema = _parse_datatype_json_string(self._jdf.schema().json())
    +            try:
    +                self._schema = _parse_datatype_json_string(self._jdf.schema().json())
    +            except AttributeError as e:
    +                raise Exception(
    +                    "Unable to parse datatype from schema. %s" % e)
             return self._schema
     
    +    @since(1.3)
         def printSchema(self):
             """Prints out the schema in the tree format.
     
    @@ -238,6 +205,7 @@ def printSchema(self):
             """
             print(self._jdf.schema().treeString())
     
    +    @since(1.3)
         def explain(self, extended=False):
             """Prints the (logical and physical) plans to the console for debugging purpose.
     
    @@ -263,15 +231,20 @@ def explain(self, extended=False):
             else:
                 print(self._jdf.queryExecution().executedPlan().toString())
     
    +    @since(1.3)
         def isLocal(self):
             """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
             (without any Spark executors).
             """
             return self._jdf.isLocal()
     
    -    def show(self, n=20):
    +    @since(1.3)
    +    def show(self, n=20, truncate=True):
             """Prints the first ``n`` rows to the console.
     
    +        :param n: Number of rows to show.
     +        :param truncate: Whether to truncate long strings and align cells right.
    +
             >>> df
             DataFrame[age: int, name: string]
             >>> df.show()
    @@ -282,11 +255,12 @@ def show(self, n=20):
             |  5|  Bob|
             +---+-----+
             """
    -        print(self._jdf.showString(n))
    +        print(self._jdf.showString(n, truncate))
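For illustration, the new truncate flag in use, with df as in the doctests:

    df.show()                   # up to 20 rows, long cells truncated
    df.show(5, truncate=False)  # 5 rows, cells printed in full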
     
         def __repr__(self):
             return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
     
    +    @since(1.3)
         def count(self):
             """Returns the number of rows in this :class:`DataFrame`.
     
    @@ -296,6 +270,7 @@ def count(self):
             return int(self._jdf.count())
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def collect(self):
             """Returns all the records as a list of :class:`Row`.
     
    @@ -304,11 +279,10 @@ def collect(self):
             """
             with SCCallSiteSync(self._sc) as css:
                 port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
    -        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
    -        cls = _create_cls(self.schema)
    -        return [cls(r) for r in rs]
    +        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def limit(self, num):
             """Limits the result count to the number specified.
     
    @@ -321,6 +295,7 @@ def limit(self, num):
             return DataFrame(jdf, self.sql_ctx)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def take(self, num):
             """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
     
    @@ -330,6 +305,7 @@ def take(self, num):
             return self.limit(num).collect()
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def map(self, f):
             """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`.
     
    @@ -341,6 +317,7 @@ def map(self, f):
             return self.rdd.map(f)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def flatMap(self, f):
             """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
             and then flattening the results.
    @@ -352,6 +329,7 @@ def flatMap(self, f):
             """
             return self.rdd.flatMap(f)
     
    +    @since(1.3)
         def mapPartitions(self, f, preservesPartitioning=False):
             """Returns a new :class:`RDD` by applying the ``f`` function to each partition.
     
    @@ -364,6 +342,7 @@ def mapPartitions(self, f, preservesPartitioning=False):
             """
             return self.rdd.mapPartitions(f, preservesPartitioning)
     
    +    @since(1.3)
         def foreach(self, f):
             """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`.
     
    @@ -375,6 +354,7 @@ def foreach(self, f):
             """
             return self.rdd.foreach(f)
     
    +    @since(1.3)
         def foreachPartition(self, f):
             """Applies the ``f`` function to each partition of this :class:`DataFrame`.
     
    @@ -387,6 +367,7 @@ def foreachPartition(self, f):
             """
             return self.rdd.foreachPartition(f)
     
    +    @since(1.3)
         def cache(self):
             """ Persists with the default storage level (C{MEMORY_ONLY_SER}).
             """
    @@ -394,6 +375,7 @@ def cache(self):
             self._jdf.cache()
             return self
     
    +    @since(1.3)
         def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
             """Sets the storage level to persist its values across operations
             after the first time it is computed. This can only be used to assign
    @@ -405,6 +387,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
             self._jdf.persist(javaStorageLevel)
             return self
     
    +    @since(1.3)
         def unpersist(self, blocking=True):
             """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from
             memory and disk.
    @@ -413,10 +396,22 @@ def unpersist(self, blocking=True):
             self._jdf.unpersist(blocking)
             return self
     
    -    # def coalesce(self, numPartitions, shuffle=False):
    -    #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
    -    #     return DataFrame(rdd, self.sql_ctx)
    +    @since(1.4)
    +    def coalesce(self, numPartitions):
    +        """
     +        Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.
    +
    +        Similar to coalesce defined on an :class:`RDD`, this operation results in a
    +        narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
     +        there will not be a shuffle; instead, each of the 100 new partitions will
    +        claim 10 of the current partitions.
    +
    +        >>> df.coalesce(1).rdd.getNumPartitions()
    +        1
    +        """
    +        return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
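A sketch contrasting the new coalesce with the existing repartition, with df as in the doctests:

    df10 = df.repartition(10)             # may shuffle to reach exactly 10 partitions
    df1 = df10.coalesce(1)                # narrow dependency: partitions are merged, no shuffle
    print(df10.rdd.getNumPartitions())    # 10
    print(df1.rdd.getNumPartitions())     # 1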
     
    +    @since(1.3)
         def repartition(self, numPartitions):
             """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions.
     
    @@ -425,6 +420,7 @@ def repartition(self, numPartitions):
             """
             return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
     
    +    @since(1.3)
         def distinct(self):
             """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
     
    @@ -433,6 +429,7 @@ def distinct(self):
             """
             return DataFrame(self._jdf.distinct(), self.sql_ctx)
     
    +    @since(1.3)
         def sample(self, withReplacement, fraction, seed=None):
             """Returns a sampled subset of this :class:`DataFrame`.
     
    @@ -444,6 +441,7 @@ def sample(self, withReplacement, fraction, seed=None):
             rdd = self._jdf.sample(withReplacement, fraction, long(seed))
             return DataFrame(rdd, self.sql_ctx)
     
    +    @since(1.4)
         def randomSplit(self, weights, seed=None):
             """Randomly splits this :class:`DataFrame` with the provided weights.
     
    @@ -466,6 +464,7 @@ def randomSplit(self, weights, seed=None):
             return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
     
         @property
    +    @since(1.3)
         def dtypes(self):
             """Returns all column names and their data types as a list.
     
    @@ -475,16 +474,17 @@ def dtypes(self):
             return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
     
         @property
    -    @ignore_unicode_prefix
    +    @since(1.3)
         def columns(self):
             """Returns all column names as a list.
     
             >>> df.columns
    -        [u'age', u'name']
    +        ['age', 'name']
             """
             return [f.name for f in self.schema.fields]
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def alias(self, alias):
             """Returns a new :class:`DataFrame` with an alias set.
     
    @@ -499,39 +499,57 @@ def alias(self, alias):
             return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
     
         @ignore_unicode_prefix
    -    def join(self, other, joinExprs=None, joinType=None):
    +    @since(1.3)
    +    def join(self, other, on=None, how=None):
             """Joins with another :class:`DataFrame`, using the given join expression.
     
             The following performs a full outer join between ``df1`` and ``df2``.
     
             :param other: Right side of the join
    -        :param joinExprs: a string for join column name, or a join expression (Column).
    -            If joinExprs is a string indicating the name of the join column,
    -            the column must exist on both sides, and this performs an inner equi-join.
    -        :param joinType: str, default 'inner'.
    +        :param on: a string for join column name, a list of column names,
     +            a join expression (Column) or a list of Columns.
     +            If `on` is a string or a list of strings indicating the name of the join column(s),
    +            the column(s) must exist on both sides, and this performs an inner equi-join.
    +        :param how: str, default 'inner'.
                 One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
     
             >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
             [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
     
    +        >>> cond = [df.name == df3.name, df.age == df3.age]
    +        >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
    +        [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
    +
             >>> df.join(df2, 'name').select(df.name, df2.height).collect()
             [Row(name=u'Bob', height=85)]
    +
    +        >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
    +        [Row(name=u'Bob', age=5)]
             """
     
    -        if joinExprs is None:
    +        if on is not None and not isinstance(on, list):
    +            on = [on]
    +
    +        if on is None or len(on) == 0:
                 jdf = self._jdf.join(other._jdf)
    -        elif isinstance(joinExprs, basestring):
    -            jdf = self._jdf.join(other._jdf, joinExprs)
    +
     +        elif isinstance(on[0], basestring):
    +            jdf = self._jdf.join(other._jdf, self._jseq(on))
             else:
    -            assert isinstance(joinExprs, Column), "joinExprs should be Column"
    -            if joinType is None:
    -                jdf = self._jdf.join(other._jdf, joinExprs._jc)
    +            assert isinstance(on[0], Column), "on should be Column or list of Column"
    +            if len(on) > 1:
    +                on = reduce(lambda x, y: x.__and__(y), on)
    +            else:
    +                on = on[0]
    +            if how is None:
    +                jdf = self._jdf.join(other._jdf, on._jc, "inner")
                 else:
    -                assert isinstance(joinType, basestring), "joinType should be basestring"
    -                jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
    +                assert isinstance(how, basestring), "how should be basestring"
    +                jdf = self._jdf.join(other._jdf, on._jc, how)
             return DataFrame(jdf, self.sql_ctx)
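Summarising the accepted forms of on, with df, df2, df3, df4 as defined for the module doctests; note that an explicit how is only honoured for Column expressions in this version:

    df.join(df2, 'name')                                    # single column name, inner equi-join
    df.join(df4, ['name', 'age'])                           # several column names at once
    df.join(df2, df.name == df2.name, 'outer')              # Column expression with a join type
    df.join(df3, [df.name == df3.name, df.age == df3.age])  # list of Columns, AND-ed together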
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def sort(self, *cols, **kwargs):
             """Returns a new :class:`DataFrame` sorted by the specified column(s).
     
    @@ -591,12 +609,16 @@ def _jcols(self, *cols):
                 cols = cols[0]
             return self._jseq(cols, _to_java_column)
     
    +    @since("1.3.1")
         def describe(self, *cols):
             """Computes statistics for numeric columns.
     
              This includes count, mean, stddev, min, and max. If no columns are
             given, this function computes statistics for all numerical columns.
     
    +        .. note:: This function is meant for exploratory data analysis, as we make no \
    +        guarantee about the backward compatibility of the schema of the resulting DataFrame.
    +
             >>> df.describe().show()
             +-------+---+
             |summary|age|
    @@ -607,15 +629,30 @@ def describe(self, *cols):
             |    min|  2|
             |    max|  5|
             +-------+---+
    +        >>> df.describe(['age', 'name']).show()
    +        +-------+---+-----+
    +        |summary|age| name|
    +        +-------+---+-----+
    +        |  count|  2|    2|
    +        |   mean|3.5| null|
    +        | stddev|1.5| null|
    +        |    min|  2|Alice|
    +        |    max|  5|  Bob|
    +        +-------+---+-----+
             """
    +        if len(cols) == 1 and isinstance(cols[0], list):
    +            cols = cols[0]
             jdf = self._jdf.describe(self._jseq(cols))
             return DataFrame(jdf, self.sql_ctx)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def head(self, n=None):
    -        """
    -        Returns the first ``n`` rows as a list of :class:`Row`,
    -        or the first :class:`Row` if ``n`` is ``None.``
    +        """Returns the first ``n`` rows.
    +
    +        :param n: int, default 1. Number of rows to return.
    +        :return: If n is greater than 1, return a list of :class:`Row`.
    +            If n is 1, return a single Row.
     
             >>> df.head()
             Row(age=2, name=u'Alice')
    @@ -628,6 +665,7 @@ def head(self, n=None):
             return self.take(n)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def first(self):
             """Returns the first row as a :class:`Row`.
     
    @@ -637,6 +675,7 @@ def first(self):
             return self.head()
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def __getitem__(self, item):
             """Returns the column as a :class:`Column`.
     
    @@ -664,6 +703,7 @@ def __getitem__(self, item):
             else:
                 raise TypeError("unexpected item type: %s" % type(item))
     
    +    @since(1.3)
         def __getattr__(self, name):
             """Returns the :class:`Column` denoted by ``name``.
     
    @@ -677,6 +717,7 @@ def __getattr__(self, name):
             return Column(jc)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def select(self, *cols):
             """Projects a set of expressions and returns a new :class:`DataFrame`.
     
    @@ -694,13 +735,14 @@ def select(self, *cols):
             jdf = self._jdf.select(self._jcols(*cols))
             return DataFrame(jdf, self.sql_ctx)
     
    +    @since(1.3)
         def selectExpr(self, *expr):
             """Projects a set of SQL expressions and returns a new :class:`DataFrame`.
     
             This is a variant of :func:`select` that accepts SQL expressions.
     
             >>> df.selectExpr("age * 2", "abs(age)").collect()
    -        [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
    +        [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)]
             """
             if len(expr) == 1 and isinstance(expr[0], list):
                 expr = expr[0]
    @@ -708,6 +750,7 @@ def selectExpr(self, *expr):
             return DataFrame(jdf, self.sql_ctx)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def filter(self, condition):
             """Filters rows using the given condition.
     
    @@ -737,6 +780,7 @@ def filter(self, condition):
         where = filter
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def groupBy(self, *cols):
             """Groups the :class:`DataFrame` using the specified columns,
             so we can run aggregation on them. See :class:`GroupedData`
    @@ -748,29 +792,76 @@ def groupBy(self, *cols):
                 Each element should be a column name (string) or an expression (:class:`Column`).
     
             >>> df.groupBy().avg().collect()
    -        [Row(AVG(age)=3.5)]
    +        [Row(avg(age)=3.5)]
             >>> df.groupBy('name').agg({'age': 'mean'}).collect()
    -        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
    +        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
             >>> df.groupBy(df.name).avg().collect()
    -        [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
    +        [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
             >>> df.groupBy(['name', df.age]).count().collect()
             [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
             """
    -        jdf = self._jdf.groupBy(self._jcols(*cols))
    -        return GroupedData(jdf, self.sql_ctx)
    -
    +        jgd = self._jdf.groupBy(self._jcols(*cols))
    +        from pyspark.sql.group import GroupedData
    +        return GroupedData(jgd, self.sql_ctx)
    +
    +    @since(1.4)
    +    def rollup(self, *cols):
    +        """
    +        Create a multi-dimensional rollup for the current :class:`DataFrame` using
    +        the specified columns, so we can run aggregation on them.
    +
    +        >>> df.rollup('name', df.age).count().show()
    +        +-----+----+-----+
    +        | name| age|count|
    +        +-----+----+-----+
    +        |Alice|null|    1|
    +        |  Bob|   5|    1|
    +        |  Bob|null|    1|
    +        | null|null|    2|
    +        |Alice|   2|    1|
    +        +-----+----+-----+
    +        """
    +        jgd = self._jdf.rollup(self._jcols(*cols))
    +        from pyspark.sql.group import GroupedData
    +        return GroupedData(jgd, self.sql_ctx)
    +
    +    @since(1.4)
    +    def cube(self, *cols):
    +        """
    +        Create a multi-dimensional cube for the current :class:`DataFrame` using
    +        the specified columns, so we can run aggregation on them.
    +
    +        >>> df.cube('name', df.age).count().show()
    +        +-----+----+-----+
    +        | name| age|count|
    +        +-----+----+-----+
    +        | null|   2|    1|
    +        |Alice|null|    1|
    +        |  Bob|   5|    1|
    +        |  Bob|null|    1|
    +        | null|   5|    1|
    +        | null|null|    2|
    +        |Alice|   2|    1|
    +        +-----+----+-----+
    +        """
    +        jgd = self._jdf.cube(self._jcols(*cols))
    +        from pyspark.sql.group import GroupedData
    +        return GroupedData(jgd, self.sql_ctx)
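A short sketch of how the three grouping flavours differ, with df as in the doctests:

    df.groupBy('name', df.age).count().show()   # only combinations that occur in the data
    df.rollup('name', df.age).count().show()    # adds per-name subtotals and a grand total (nulls)
    df.cube('name', df.age).count().show()      # subtotals for every subset of the grouping columns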
    +
    +    @since(1.3)
         def agg(self, *exprs):
             """ Aggregate on the entire :class:`DataFrame` without groups
             (shorthand for ``df.groupBy.agg()``).
     
             >>> df.agg({"age": "max"}).collect()
    -        [Row(MAX(age)=5)]
    +        [Row(max(age)=5)]
             >>> from pyspark.sql import functions as F
             >>> df.agg(F.min(df.age)).collect()
    -        [Row(MIN(age)=2)]
    +        [Row(min(age)=2)]
             """
             return self.groupBy().agg(*exprs)
     
    +    @since(1.3)
         def unionAll(self, other):
             """ Return a new :class:`DataFrame` containing union of rows in this
             frame and another frame.
    @@ -779,6 +870,7 @@ def unionAll(self, other):
             """
             return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
     
    +    @since(1.3)
         def intersect(self, other):
             """ Return a new :class:`DataFrame` containing rows only in
             both this frame and another frame.
    @@ -787,6 +879,7 @@ def intersect(self, other):
             """
             return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
     
    +    @since(1.3)
         def subtract(self, other):
             """ Return a new :class:`DataFrame` containing rows in this frame
             but not in another frame.
    @@ -795,6 +888,7 @@ def subtract(self, other):
             """
             return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
     
    +    @since(1.4)
         def dropDuplicates(self, subset=None):
             """Return a new :class:`DataFrame` with duplicate rows removed,
             optionally only considering certain columns.
    @@ -825,10 +919,10 @@ def dropDuplicates(self, subset=None):
                 jdf = self._jdf.dropDuplicates(self._jseq(subset))
             return DataFrame(jdf, self.sql_ctx)
     
    +    @since("1.3.1")
         def dropna(self, how='any', thresh=None, subset=None):
             """Returns a new :class:`DataFrame` omitting rows with null values.
    -
    -        This is an alias for ``na.drop()``.
    +        :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.
     
             :param how: 'any' or 'all'.
                 If 'any', drop a row if it contains any nulls.
    @@ -838,13 +932,6 @@ def dropna(self, how='any', thresh=None, subset=None):
                 This overwrites the `how` parameter.
             :param subset: optional list of column names to consider.
     
    -        >>> df4.dropna().show()
    -        +---+------+-----+
    -        |age|height| name|
    -        +---+------+-----+
    -        | 10|    80|Alice|
    -        +---+------+-----+
    -
             >>> df4.na.drop().show()
             +---+------+-----+
             |age|height| name|
    @@ -867,8 +954,10 @@ def dropna(self, how='any', thresh=None, subset=None):
     
             return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
     
    +    @since("1.3.1")
         def fillna(self, value, subset=None):
             """Replace null values, alias for ``na.fill()``.
    +        :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.
     
             :param value: int, long, float, string, or dict.
                 Value to replace null values with.
    @@ -880,7 +969,7 @@ def fillna(self, value, subset=None):
                 For example, if `value` is a string, and subset contains a non-string column,
                 then the non-string column is simply ignored.
     
    -        >>> df4.fillna(50).show()
    +        >>> df4.na.fill(50).show()
             +---+------+-----+
             |age|height| name|
             +---+------+-----+
    @@ -890,16 +979,6 @@ def fillna(self, value, subset=None):
             | 50|    50| null|
             +---+------+-----+
     
    -        >>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
    -        +---+------+-------+
    -        |age|height|   name|
    -        +---+------+-------+
    -        | 10|    80|  Alice|
    -        |  5|  null|    Bob|
    -        | 50|  null|    Tom|
    -        | 50|  null|unknown|
    -        +---+------+-------+
    -
             >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
             +---+------+-------+
             |age|height|   name|
    @@ -928,8 +1007,11 @@ def fillna(self, value, subset=None):
     
                 return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
     
    +    @since(1.4)
         def replace(self, to_replace, value, subset=None):
             """Returns a new :class:`DataFrame` replacing a value with another value.
    +        :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
    +        aliases of each other.
     
             :param to_replace: int, long, float, string, or list.
                 Value to be replaced.
    @@ -944,7 +1026,8 @@ def replace(self, to_replace, value, subset=None):
                 Columns specified in subset that do not have matching data type are ignored.
                 For example, if `value` is a string, and subset contains a non-string column,
                 then the non-string column is simply ignored.
    -        >>> df4.replace(10, 20).show()
    +
    +        >>> df4.na.replace(10, 20).show()
             +----+------+-----+
             | age|height| name|
             +----+------+-----+
    @@ -954,7 +1037,7 @@ def replace(self, to_replace, value, subset=None):
             |null|  null| null|
             +----+------+-----+
     
    -        >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
    +        >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
             +----+------+----+
             | age|height|name|
             +----+------+----+
    @@ -1002,11 +1085,12 @@ def replace(self, to_replace, value, subset=None):
             return DataFrame(
                 self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
     
    +    @since(1.4)
         def corr(self, col1, col2, method=None):
             """
    -        Calculates the correlation of two columns of a DataFrame as a double value. Currently only
    -        supports the Pearson Correlation Coefficient.
    -        :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
    +        Calculates the correlation of two columns of a DataFrame as a double value.
    +        Currently only supports the Pearson Correlation Coefficient.
    +        :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other.
     
             :param col1: The name of the first column
             :param col2: The name of the second column
    @@ -1023,6 +1107,7 @@ def corr(self, col1, col2, method=None):
                                  "coefficient is supported.")
             return self._jdf.stat().corr(col1, col2, method)
     
    +    @since(1.4)
         def cov(self, col1, col2):
             """
             Calculate the sample covariance for the given columns, specified by their names, as a
    @@ -1037,6 +1122,7 @@ def cov(self, col1, col2):
                 raise ValueError("col2 should be a string.")
             return self._jdf.stat().cov(col1, col2)
     
    +    @since(1.4)
         def crosstab(self, col1, col2):
             """
             Computes a pair-wise frequency table of the given columns. Also known as a contingency
    @@ -1058,6 +1144,7 @@ def crosstab(self, col1, col2):
                 raise ValueError("col2 should be a string.")
             return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
     
    +    @since(1.4)
         def freqItems(self, cols, support=None):
             """
             Finding frequent items for columns, possibly with false positives. Using the
    @@ -1065,6 +1152,9 @@ def freqItems(self, cols, support=None):
             "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
             :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.
     
    +        .. note::  This function is meant for exploratory data analysis, as we make no \
    +        guarantee about the backward compatibility of the schema of the resulting DataFrame.
    +
             :param cols: Names of the columns to calculate frequent items for as a list or tuple of
                 strings.
             :param support: The frequency with which to consider an item 'frequent'. Default is 1%.
    @@ -1079,6 +1169,7 @@ def freqItems(self, cols, support=None):
             return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx)
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def withColumn(self, colName, col):
             """Returns a new :class:`DataFrame` by adding a column.
     
    @@ -1091,6 +1182,7 @@ def withColumn(self, colName, col):
             return self.select('*', col.alias(colName))
     
         @ignore_unicode_prefix
    +    @since(1.3)
         def withColumnRenamed(self, existing, new):
             """Returns a new :class:`DataFrame` by renaming an existing column.
     
    @@ -1105,18 +1197,35 @@ def withColumnRenamed(self, existing, new):
                     for c in self.columns]
             return self.select(*cols)
     
    +    @since(1.4)
         @ignore_unicode_prefix
    -    def drop(self, colName):
    +    def drop(self, col):
             """Returns a new :class:`DataFrame` that drops the specified column.
     
    -        :param colName: string, name of the column to drop.
    +        :param col: a string name of the column to drop, or a
    +            :class:`Column` to drop.
     
             >>> df.drop('age').collect()
             [Row(name=u'Alice'), Row(name=u'Bob')]
    +
    +        >>> df.drop(df.age).collect()
    +        [Row(name=u'Alice'), Row(name=u'Bob')]
    +
    +        >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
    +        [Row(age=5, height=85, name=u'Bob')]
    +
    +        >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
    +        [Row(age=5, name=u'Bob', height=85)]
             """
    -        jdf = self._jdf.drop(colName)
    +        if isinstance(col, basestring):
    +            jdf = self._jdf.drop(col)
    +        elif isinstance(col, Column):
    +            jdf = self._jdf.drop(col._jc)
    +        else:
    +            raise TypeError("col should be a string or a Column")
             return DataFrame(jdf, self.sql_ctx)
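Dropping by Column is mainly useful when a name is ambiguous after a join; a sketch using the doctest DataFrames:

    joined = df.join(df2, df.name == df2.name, 'inner')
    deduped = joined.drop(df2.name)       # 'name' exists on both sides, so drop by Column
    print(deduped.collect())              # roughly [Row(age=5, name=u'Bob', height=85)]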
     
    +    @since(1.3)
         def toPandas(self):
             """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
     
    @@ -1130,7 +1239,10 @@ def toPandas(self):
             import pandas as pd
             return pd.DataFrame.from_records(self.collect(), columns=self.columns)
     
    +    ##########################################################################################
         # Pandas compatibility
    +    ##########################################################################################
    +
         groupby = groupBy
         drop_duplicates = dropDuplicates
     
    @@ -1141,169 +1253,6 @@ class SchemaRDD(DataFrame):
         """
     
     
    -def dfapi(f):
    -    def _api(self):
    -        name = f.__name__
    -        jdf = getattr(self._jdf, name)()
    -        return DataFrame(jdf, self.sql_ctx)
    -    _api.__name__ = f.__name__
    -    _api.__doc__ = f.__doc__
    -    return _api
    -
    -
    -def df_varargs_api(f):
    -    def _api(self, *args):
    -        name = f.__name__
    -        jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
    -        return DataFrame(jdf, self.sql_ctx)
    -    _api.__name__ = f.__name__
    -    _api.__doc__ = f.__doc__
    -    return _api
    -
    -
    -class GroupedData(object):
    -    """
    -    A set of methods for aggregations on a :class:`DataFrame`,
    -    created by :func:`DataFrame.groupBy`.
    -    """
    -
    -    def __init__(self, jdf, sql_ctx):
    -        self._jdf = jdf
    -        self.sql_ctx = sql_ctx
    -
    -    @ignore_unicode_prefix
    -    def agg(self, *exprs):
    -        """Compute aggregates and returns the result as a :class:`DataFrame`.
    -
    -        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
    -
    -        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
    -        is the column to perform aggregation on, and the value is the aggregate function.
    -
    -        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
    -
    -        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
    -            or a list of :class:`Column`.
    -
    -        >>> gdf = df.groupBy(df.name)
    -        >>> gdf.agg({"*": "count"}).collect()
    -        [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
    -
    -        >>> from pyspark.sql import functions as F
    -        >>> gdf.agg(F.min(df.age)).collect()
    -        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
    -        """
    -        assert exprs, "exprs should not be empty"
    -        if len(exprs) == 1 and isinstance(exprs[0], dict):
    -            jdf = self._jdf.agg(exprs[0])
    -        else:
    -            # Columns
    -            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
    -            jdf = self._jdf.agg(exprs[0]._jc,
    -                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
    -        return DataFrame(jdf, self.sql_ctx)
    -
    -    @dfapi
    -    def count(self):
    -        """Counts the number of records for each group.
    -
    -        >>> df.groupBy(df.age).count().collect()
    -        [Row(age=2, count=1), Row(age=5, count=1)]
    -        """
    -
    -    @df_varargs_api
    -    def mean(self, *cols):
    -        """Computes average values for each numeric columns for each group.
    -
    -        :func:`mean` is an alias for :func:`avg`.
    -
    -        :param cols: list of column names (string). Non-numeric columns are ignored.
    -
    -        >>> df.groupBy().mean('age').collect()
    -        [Row(AVG(age)=3.5)]
    -        >>> df3.groupBy().mean('age', 'height').collect()
    -        [Row(AVG(age)=3.5, AVG(height)=82.5)]
    -        """
    -
    -    @df_varargs_api
    -    def avg(self, *cols):
    -        """Computes average values for each numeric columns for each group.
    -
    -        :func:`mean` is an alias for :func:`avg`.
    -
    -        :param cols: list of column names (string). Non-numeric columns are ignored.
    -
    -        >>> df.groupBy().avg('age').collect()
    -        [Row(AVG(age)=3.5)]
    -        >>> df3.groupBy().avg('age', 'height').collect()
    -        [Row(AVG(age)=3.5, AVG(height)=82.5)]
    -        """
    -
    -    @df_varargs_api
    -    def max(self, *cols):
    -        """Computes the max value for each numeric columns for each group.
    -
    -        >>> df.groupBy().max('age').collect()
    -        [Row(MAX(age)=5)]
    -        >>> df3.groupBy().max('age', 'height').collect()
    -        [Row(MAX(age)=5, MAX(height)=85)]
    -        """
    -
    -    @df_varargs_api
    -    def min(self, *cols):
    -        """Computes the min value for each numeric column for each group.
    -
    -        :param cols: list of column names (string). Non-numeric columns are ignored.
    -
    -        >>> df.groupBy().min('age').collect()
    -        [Row(MIN(age)=2)]
    -        >>> df3.groupBy().min('age', 'height').collect()
    -        [Row(MIN(age)=2, MIN(height)=80)]
    -        """
    -
    -    @df_varargs_api
    -    def sum(self, *cols):
    -        """Compute the sum for each numeric columns for each group.
    -
    -        :param cols: list of column names (string). Non-numeric columns are ignored.
    -
    -        >>> df.groupBy().sum('age').collect()
    -        [Row(SUM(age)=7)]
    -        >>> df3.groupBy().sum('age', 'height').collect()
    -        [Row(SUM(age)=7, SUM(height)=165)]
    -        """
    -
    -
    -def _create_column_from_literal(literal):
    -    sc = SparkContext._active_spark_context
    -    return sc._jvm.functions.lit(literal)
    -
    -
    -def _create_column_from_name(name):
    -    sc = SparkContext._active_spark_context
    -    return sc._jvm.functions.col(name)
    -
    -
    -def _to_java_column(col):
    -    if isinstance(col, Column):
    -        jcol = col._jc
    -    else:
    -        jcol = _create_column_from_name(col)
    -    return jcol
    -
    -
    -def _to_seq(sc, cols, converter=None):
    -    """
    -    Convert a list of Column (or names) into a JVM Seq of Column.
    -
    -    An optional `converter` could be used to convert items in `cols`
    -    into JVM Column objects.
    -    """
    -    if converter:
    -        cols = [converter(c) for c in cols]
    -    return sc._jvm.PythonUtils.toSeq(cols)
    -
    -
     def _to_scala_map(sc, jm):
         """
         Convert a dict into a JVM Map.
    @@ -1311,278 +1260,10 @@ def _to_scala_map(sc, jm):
         return sc._jvm.PythonUtils.toScalaMap(jm)
     
     
    -def _unary_op(name, doc="unary operator"):
    -    """ Create a method for given unary operator """
    -    def _(self):
    -        jc = getattr(self._jc, name)()
    -        return Column(jc)
    -    _.__doc__ = doc
    -    return _
    -
    -
    -def _func_op(name, doc=''):
    -    def _(self):
    -        sc = SparkContext._active_spark_context
    -        jc = getattr(sc._jvm.functions, name)(self._jc)
    -        return Column(jc)
    -    _.__doc__ = doc
    -    return _
    -
    -
    -def _bin_op(name, doc="binary operator"):
    -    """ Create a method for given binary operator
    -    """
    -    def _(self, other):
    -        jc = other._jc if isinstance(other, Column) else other
    -        njc = getattr(self._jc, name)(jc)
    -        return Column(njc)
    -    _.__doc__ = doc
    -    return _
    -
    -
    -def _reverse_op(name, doc="binary operator"):
    -    """ Create a method for binary operator (this object is on right side)
    -    """
    -    def _(self, other):
    -        jother = _create_column_from_literal(other)
    -        jc = getattr(jother, name)(self._jc)
    -        return Column(jc)
    -    _.__doc__ = doc
    -    return _
    -
    -
    -class Column(object):
    -
    -    """
    -    A column in a DataFrame.
    -
    -    :class:`Column` instances can be created by::
    -
    -        # 1. Select a column out of a DataFrame
    -
    -        df.colName
    -        df["colName"]
    -
    -        # 2. Create from an expression
    -        df.colName + 1
    -        1 / df.colName
    -    """
    -
    -    def __init__(self, jc):
    -        self._jc = jc
    -
    -    # arithmetic operators
    -    __neg__ = _func_op("negate")
    -    __add__ = _bin_op("plus")
    -    __sub__ = _bin_op("minus")
    -    __mul__ = _bin_op("multiply")
    -    __div__ = _bin_op("divide")
    -    __truediv__ = _bin_op("divide")
    -    __mod__ = _bin_op("mod")
    -    __radd__ = _bin_op("plus")
    -    __rsub__ = _reverse_op("minus")
    -    __rmul__ = _bin_op("multiply")
    -    __rdiv__ = _reverse_op("divide")
    -    __rtruediv__ = _reverse_op("divide")
    -    __rmod__ = _reverse_op("mod")
    -
    -    # logistic operators
    -    __eq__ = _bin_op("equalTo")
    -    __ne__ = _bin_op("notEqual")
    -    __lt__ = _bin_op("lt")
    -    __le__ = _bin_op("leq")
    -    __ge__ = _bin_op("geq")
    -    __gt__ = _bin_op("gt")
    -
    -    # `and`, `or`, `not` cannot be overloaded in Python,
    -    # so use bitwise operators as boolean operators
    -    __and__ = _bin_op('and')
    -    __or__ = _bin_op('or')
    -    __invert__ = _func_op('not')
    -    __rand__ = _bin_op("and")
    -    __ror__ = _bin_op("or")
    -
    -    # container operators
    -    __contains__ = _bin_op("contains")
    -    __getitem__ = _bin_op("apply")
    -
    -    # bitwise operators
    -    bitwiseOR = _bin_op("bitwiseOR")
    -    bitwiseAND = _bin_op("bitwiseAND")
    -    bitwiseXOR = _bin_op("bitwiseXOR")
    -
    -    def getItem(self, key):
    -        """An expression that gets an item at position `ordinal` out of a list,
    -         or gets an item by key out of a dict.
    -
    -        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
    -        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
    -        +----+------+
    -        |l[0]|d[key]|
    -        +----+------+
    -        |   1| value|
    -        +----+------+
    -        >>> df.select(df.l[0], df.d["key"]).show()
    -        +----+------+
    -        |l[0]|d[key]|
    -        +----+------+
    -        |   1| value|
    -        +----+------+
    -        """
    -        return self[key]
    -
    -    def getField(self, name):
    -        """An expression that gets a field by name in a StructField.
    -
    -        >>> from pyspark.sql import Row
    -        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
    -        >>> df.select(df.r.getField("b")).show()
    -        +----+
    -        |r[b]|
    -        +----+
    -        |   b|
    -        +----+
    -        >>> df.select(df.r.a).show()
    -        +----+
    -        |r[a]|
    -        +----+
    -        |   1|
    -        +----+
    -        """
    -        return self[name]
    -
    -    def __getattr__(self, item):
    -        if item.startswith("__"):
    -            raise AttributeError(item)
    -        return self.getField(item)
    -
    -    # string methods
    -    rlike = _bin_op("rlike")
    -    like = _bin_op("like")
    -    startswith = _bin_op("startsWith")
    -    endswith = _bin_op("endsWith")
    -
    -    @ignore_unicode_prefix
    -    def substr(self, startPos, length):
    -        """
    -        Return a :class:`Column` which is a substring of the column
    -
    -        :param startPos: start position (int or Column)
    -        :param length:  length of the substring (int or Column)
    -
    -        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
    -        [Row(col=u'Ali'), Row(col=u'Bob')]
    -        """
    -        if type(startPos) != type(length):
    -            raise TypeError("Can not mix the type")
    -        if isinstance(startPos, (int, long)):
    -            jc = self._jc.substr(startPos, length)
    -        elif isinstance(startPos, Column):
    -            jc = self._jc.substr(startPos._jc, length._jc)
    -        else:
    -            raise TypeError("Unexpected type: %s" % type(startPos))
    -        return Column(jc)
    -
    -    __getslice__ = substr
    -
    -    @ignore_unicode_prefix
    -    def inSet(self, *cols):
    -        """ A boolean expression that is evaluated to true if the value of this
    -        expression is contained by the evaluated values of the arguments.
    -
    -        >>> df[df.name.inSet("Bob", "Mike")].collect()
    -        [Row(age=5, name=u'Bob')]
    -        >>> df[df.age.inSet([1, 2, 3])].collect()
    -        [Row(age=2, name=u'Alice')]
    -        """
    -        if len(cols) == 1 and isinstance(cols[0], (list, set)):
    -            cols = cols[0]
    -        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
    -        sc = SparkContext._active_spark_context
    -        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
    -        return Column(jc)
    -
    -    # order
    -    asc = _unary_op("asc", "Returns a sort expression based on the"
    -                           " ascending order of the given column name.")
    -    desc = _unary_op("desc", "Returns a sort expression based on the"
    -                             " descending order of the given column name.")
    -
    -    isNull = _unary_op("isNull", "True if the current expression is null.")
    -    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
    -
    -    def alias(self, alias):
    -        """Return a alias for this column
    -
    -        >>> df.select(df.age.alias("age2")).collect()
    -        [Row(age2=2), Row(age2=5)]
    -        """
    -        return Column(getattr(self._jc, "as")(alias))
    -
    -    @ignore_unicode_prefix
    -    def cast(self, dataType):
    -        """ Convert the column into type `dataType`
    -
    -        >>> df.select(df.age.cast("string").alias('ages')).collect()
    -        [Row(ages=u'2'), Row(ages=u'5')]
    -        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
    -        [Row(ages=u'2'), Row(ages=u'5')]
    -        """
    -        if isinstance(dataType, basestring):
    -            jc = self._jc.cast(dataType)
    -        elif isinstance(dataType, DataType):
    -            sc = SparkContext._active_spark_context
    -            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
    -            jdt = ssql_ctx.parseDataType(dataType.json())
    -            jc = self._jc.cast(jdt)
    -        else:
    -            raise TypeError("unexpected type: %s" % type(dataType))
    -        return Column(jc)
    -
    -    @ignore_unicode_prefix
    -    def between(self, lowerBound, upperBound):
    -        """ A boolean expression that is evaluated to true if the value of this
    -        expression is between the given columns.
    -        """
    -        return (self >= lowerBound) & (self <= upperBound)
    -
    -    @ignore_unicode_prefix
    -    def when(self, condition, value):
    -        """Evaluates a list of conditions and returns one of multiple possible result expressions.
    -        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
    -
    -        See :func:`pyspark.sql.functions.when` for example usage.
    -
    -        :param condition: a boolean :class:`Column` expression.
    -        :param value: a literal value, or a :class:`Column` expression.
    -
    -        """
    -        sc = SparkContext._active_spark_context
    -        if not isinstance(condition, Column):
    -            raise TypeError("condition should be a Column")
    -        v = value._jc if isinstance(value, Column) else value
    -        jc = sc._jvm.functions.when(condition._jc, v)
    -        return Column(jc)
    -
    -    @ignore_unicode_prefix
    -    def otherwise(self, value):
    -        """Evaluates a list of conditions and returns one of multiple possible result expressions.
    -        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
    -
    -        See :func:`pyspark.sql.functions.when` for example usage.
    -
    -        :param value: a literal value, or a :class:`Column` expression.
    -        """
    -        v = value._jc if isinstance(value, Column) else value
    -        jc = self._jc.otherwise(value)
    -        return Column(jc)
    -
    -    def __repr__(self):
    -        return 'Column<%s>' % self._jc.toString().encode('utf8')
    -
    -
     class DataFrameNaFunctions(object):
         """Functionality for working with missing data in :class:`DataFrame`.
    +
    +    .. versionadded:: 1.4
         """
     
         def __init__(self, df):
    @@ -1598,9 +1279,16 @@ def fill(self, value, subset=None):
     
         fill.__doc__ = DataFrame.fillna.__doc__
     
    +    def replace(self, to_replace, value, subset=None):
    +        return self.df.replace(to_replace, value, subset)
    +
    +    replace.__doc__ = DataFrame.replace.__doc__
    +
     
     class DataFrameStatFunctions(object):
         """Functionality for statistic functions with :class:`DataFrame`.
    +
    +    .. versionadded:: 1.4
         """
     
         def __init__(self, df):
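The `replace` forwarder added above delegates straight to `DataFrame.replace`, so `df.na.replace(...)` and `df.replace(...)` go through the same JVM call. A minimal usage sketch, assuming a local context (the column names and values are illustrative, not taken from this patch's test fixtures):

```
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext('local[2]', 'na-replace-example')
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([(10, 'Alice'), (5, 'Bob')], ['age', 'name'])

# Both calls are equivalent; the na accessor simply forwards to DataFrame.replace.
print(df.na.replace(10, 20, subset=['age']).collect())
print(df.replace(10, 20, subset=['age']).collect())

sc.stop()
```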
    @@ -1640,9 +1328,8 @@ def _test():
             .toDF(StructType([StructField('age', IntegerType()),
                               StructField('name', StringType())]))
         globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
    -    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
    -                                  Row(name='Bob', age=5, height=85)]).toDF()
    -
    +    globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
    +                                   Row(name='Bob', age=5)]).toDF()
         globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
                                       Row(name='Bob', age=5, height=None),
                                       Row(name='Tom', age=None, height=None),
    diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
    index d91265ee0bec8..dca39fa833435 100644
    --- a/python/pyspark/sql/functions.py
    +++ b/python/pyspark/sql/functions.py
    @@ -18,6 +18,7 @@
     """
     A collections of builtin functions
     """
    +import math
     import sys
     
     if sys.version < "3":
    @@ -26,21 +27,33 @@
     from pyspark import SparkContext
     from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
     from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
    +from pyspark.sql import since
     from pyspark.sql.types import StringType
    -from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
    +from pyspark.sql.column import Column, _to_java_column, _to_seq
     
     
     __all__ = [
    +    'array',
         'approxCountDistinct',
    +    'bin',
         'coalesce',
         'countDistinct',
    +    'explode',
    +    'log2',
    +    'md5',
         'monotonicallyIncreasingId',
         'rand',
         'randn',
    +    'sha1',
    +    'sha2',
         'sparkPartitionId',
    +    'strlen',
    +    'struct',
         'udf',
         'when']
     
    +__all__ += ['lag', 'lead', 'ntile']
    +
     
     def _create_function(name, doc=""):
         """ Create a function for aggregator by name"""
    @@ -66,6 +79,17 @@ def _(col1, col2):
         return _
     
     
    +def _create_window_function(name, doc=''):
    +    """ Create a window function by name """
    +    def _():
    +        sc = SparkContext._active_spark_context
    +        jc = getattr(sc._jvm.functions, name)()
    +        return Column(jc)
    +    _.__name__ = name
    +    _.__doc__ = 'Window function: ' + doc
    +    return _
    +
    +
     _functions = {
         'lit': 'Creates a :class:`Column` of literal value.',
         'col': 'Returns a :class:`Column` based on the given column name.',
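The `_create_window_function` helper introduced above follows the same factory pattern as `_create_function`: each generated wrapper is a zero-argument callable that looks up the JVM function of the same name and wraps the result in a `Column`, and the registration loops later in this diff install the wrappers into the module's globals. A plain-Python sketch of that pattern only; the `FakeJVMFunctions` stand-in is illustrative, while the real code goes through `sc._jvm.functions` and `Column`:

```
class FakeJVMFunctions(object):
    def rowNumber(self):
        return '<jvm column: row_number()>'

fake_jvm = FakeJVMFunctions()

def _create_window_function(name, doc=''):
    def _():
        # the real implementation wraps this result in Column(...)
        return getattr(fake_jvm, name)()
    _.__name__ = name
    _.__doc__ = 'Window function: ' + doc
    return _

_window_functions = {'rowNumber': 'returns a sequential number within a window partition.'}
for _name, _doc in _window_functions.items():
    globals()[_name] = _create_window_function(_name, _doc)

print(rowNumber())          # '<jvm column: row_number()>'
print(rowNumber.__doc__)    # 'Window function: returns a sequential number ...'
```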
    @@ -78,6 +102,18 @@ def _(col1, col2):
         'sqrt': 'Computes the square root of the specified float value.',
         'abs': 'Computes the absolute value.',
     
    +    'max': 'Aggregate function: returns the maximum value of the expression in a group.',
    +    'min': 'Aggregate function: returns the minimum value of the expression in a group.',
    +    'first': 'Aggregate function: returns the first value in a group.',
    +    'last': 'Aggregate function: returns the last value in a group.',
    +    'count': 'Aggregate function: returns the number of items in a group.',
    +    'sum': 'Aggregate function: returns the sum of all values in the expression.',
    +    'avg': 'Aggregate function: returns the average of the values in a group.',
    +    'mean': 'Aggregate function: returns the average of the values in a group.',
    +    'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
    +}
    +
    +_functions_1_4 = {
         # unary math functions
         'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
                 '0.0 through pi.',
    @@ -102,21 +138,11 @@ def _(col1, col2):
         'tan': 'Computes the tangent of the given value.',
         'tanh': 'Computes the hyperbolic tangent of the given value.',
         'toDegrees': 'Converts an angle measured in radians to an approximately equivalent angle ' +
    -             'measured in degrees.',
    +                 'measured in degrees.',
         'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
    -             'measured in radians.',
    +                 'measured in radians.',
     
         'bitwiseNOT': 'Computes bitwise not.',
    -
    -    'max': 'Aggregate function: returns the maximum value of the expression in a group.',
    -    'min': 'Aggregate function: returns the minimum value of the expression in a group.',
    -    'first': 'Aggregate function: returns the first value in a group.',
    -    'last': 'Aggregate function: returns the last value in a group.',
    -    'count': 'Aggregate function: returns the number of items in a group.',
    -    'sum': 'Aggregate function: returns the sum of all values in the expression.',
    -    'avg': 'Aggregate function: returns the average of the values in a group.',
    -    'mean': 'Aggregate function: returns the average of the values in a group.',
    -    'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
     }
     
     # math functions that take two arguments as input
    @@ -124,19 +150,60 @@ def _(col1, col2):
         'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
                  'polar coordinates (r, theta).',
         'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
    -    'pow': 'Returns the value of the first argument raised to the power of the second argument.'
    +    'pow': 'Returns the value of the first argument raised to the power of the second argument.',
    +}
    +
    +_window_functions = {
    +    'rowNumber':
    +        """returns a sequential number starting at 1 within a window partition.
    +
    +        This is equivalent to the ROW_NUMBER function in SQL.""",
    +    'denseRank':
    +        """returns the rank of rows within a window partition, without any gaps.
    +
    +        The difference between rank and denseRank is that denseRank leaves no gaps in ranking
    +        sequence when there are ties. That is, if you were ranking a competition using denseRank
    +        and had three people tie for second place, you would say that all three were in second
    +        place and that the next person came in third.
    +
    +        This is equivalent to the DENSE_RANK function in SQL.""",
    +    'rank':
    +        """returns the rank of rows within a window partition.
    +
    +        The difference between rank and denseRank is that denseRank leaves no gaps in ranking
    +        sequence when there are ties. That is, if you were ranking a competition using denseRank
    +        and had three people tie for second place, you would say that all three were in second
    +        place and that the next person came in third.
    +
    +        This is equivalent to the RANK function in SQL.""",
    +    'cumeDist':
    +        """returns the cumulative distribution of values within a window partition,
    +        i.e. the fraction of rows that are below the current row.
    +
    +        This is equivalent to the CUME_DIST function in SQL.""",
    +    'percentRank':
    +        """returns the relative rank (i.e. percentile) of rows within a window partition.
    +
    +        This is equivalent to the PERCENT_RANK function in SQL.""",
     }
     
     for _name, _doc in _functions.items():
    -    globals()[_name] = _create_function(_name, _doc)
    +    globals()[_name] = since(1.3)(_create_function(_name, _doc))
    +for _name, _doc in _functions_1_4.items():
    +    globals()[_name] = since(1.4)(_create_function(_name, _doc))
     for _name, _doc in _binary_mathfunctions.items():
    -    globals()[_name] = _create_binary_mathfunction(_name, _doc)
    +    globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
    +for _name, _doc in _window_functions.items():
    +    globals()[_name] = since(1.4)(_create_window_function(_name, _doc))
     del _name, _doc
     __all__ += _functions.keys()
    +__all__ += _functions_1_4.keys()
     __all__ += _binary_mathfunctions.keys()
    +__all__ += _window_functions.keys()
     __all__.sort()
     
     
    +@since(1.4)
     def array(*cols):
         """Creates a new array column.
     
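Every generated function is additionally wrapped in `since(version)`, imported from `pyspark.sql` earlier in this diff; its definition is not part of this hunk. Conceptually it is a decorator factory that stamps the version onto the wrapped function's docstring, matching the `.. versionadded::` notes used elsewhere in this patch. A hedged sketch of that idea only, not the actual implementation:

```
# Sketch only: mimics the observable effect of pyspark.sql.since, i.e. appending
# a ".. versionadded::" note to the decorated function's docstring.
def since(version):
    def deco(f):
        f.__doc__ = (f.__doc__ or '') + '\n\n.. versionadded:: %s' % version
        return f
    return deco

@since(1.4)
def demo(x):
    """Returns the given value unchanged."""
    return x

print(demo.__doc__)   # ends with ".. versionadded:: 1.4"
```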
    @@ -155,6 +222,7 @@ def array(*cols):
         return Column(jc)
     
     
    +@since(1.3)
     def approxCountDistinct(col, rsd=None):
         """Returns a new :class:`Column` for approximate distinct count of ``col``.
     
    @@ -169,6 +237,20 @@ def approxCountDistinct(col, rsd=None):
         return Column(jc)
     
     
    +@ignore_unicode_prefix
    +@since(1.5)
    +def bin(col):
    +    """Returns the string representation of the binary value of the given column.
    +
    +    >>> df.select(bin(df.age).alias('c')).collect()
    +    [Row(c=u'10'), Row(c=u'101')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.bin(_to_java_column(col))
    +    return Column(jc)
    +
    +
    +@since(1.4)
     def coalesce(*cols):
         """Returns the first column that is not null.
     
    @@ -184,7 +266,7 @@ def coalesce(*cols):
     
         >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
         +-------------+
    -    |Coalesce(a,b)|
    +    |coalesce(a,b)|
         +-------------+
         |         null|
         |            1|
    @@ -193,7 +275,7 @@ def coalesce(*cols):
     
         >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
         +----+----+---------------+
    -    |   a|   b|Coalesce(a,0.0)|
    +    |   a|   b|coalesce(a,0.0)|
         +----+----+---------------+
         |null|null|            0.0|
         |   1|null|            1.0|
    @@ -205,6 +287,7 @@ def coalesce(*cols):
         return Column(jc)
     
     
    +@since(1.3)
     def countDistinct(col, *cols):
         """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
     
    @@ -219,6 +302,55 @@ def countDistinct(col, *cols):
         return Column(jc)
     
     
    +@since(1.4)
    +def explode(col):
    +    """Returns a new row for each element in the given array or map.
    +
    +    >>> from pyspark.sql import Row
    +    >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
    +    >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
    +    [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
    +
    +    >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
    +    +---+-----+
    +    |key|value|
    +    +---+-----+
    +    |  a|    b|
    +    +---+-----+
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.explode(_to_java_column(col))
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def levenshtein(left, right):
    +    """Computes the Levenshtein distance of the two given strings.
    +
    +    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
    +    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
    +    [Row(d=3)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def md5(col):
    +    """Calculates the MD5 digest and returns the value as a 32 character hex string.
    +
    +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
    +    [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.md5(_to_java_column(col))
    +    return Column(jc)
    +
    +
    +@since(1.4)
     def monotonicallyIncreasingId():
         """A column that generates monotonically increasing 64-bit integers.
     
    @@ -227,7 +359,7 @@ def monotonicallyIncreasingId():
         within each partition in the lower 33 bits. The assumption is that the data frame has
         less than 1 billion partitions, and each partition has less than 8 billion records.
     
    -    As an example, consider a [[DataFrame]] with two partitions, each with 3 records.
    +    As an example, consider a :class:`DataFrame` with two partitions, each with 3 records.
         This expression would return the following IDs:
         0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
     
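The ID layout described in the docstring above (partition ID in the upper 31 bits, per-partition record number in the lower 33 bits) can be reproduced with plain integer arithmetic; this sketch shows where the 8589934592-range values in the example come from:

```
# Two partitions with three records each:
# id = (partition_id << 33) + record_number_within_partition
ids = [(p << 33) + r for p in range(2) for r in range(3)]
print(ids)   # [0, 1, 2, 8589934592, 8589934593, 8589934594]
```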
    @@ -239,6 +371,7 @@ def monotonicallyIncreasingId():
         return Column(sc._jvm.functions.monotonicallyIncreasingId())
     
     
    +@since(1.4)
     def rand(seed=None):
         """Generates a random column with i.i.d. samples from U[0.0, 1.0].
         """
    @@ -250,6 +383,7 @@ def rand(seed=None):
         return Column(jc)
     
     
    +@since(1.4)
     def randn(seed=None):
         """Generates a column with i.i.d. samples from the standard normal distribution.
         """
    @@ -261,6 +395,103 @@ def randn(seed=None):
         return Column(jc)
     
     
    +@ignore_unicode_prefix
    +@since(1.5)
    +def hex(col):
    +    """Computes hex value of the given column, which could be StringType,
    +    BinaryType, IntegerType or LongType.
    +
    +    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
    +    [Row(hex(a)=u'414243', hex(b)=u'3')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.hex(_to_java_column(col))
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def unhex(col):
    +    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
    +    and converts it to the byte representation of the number.
    +
    +    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
    +    [Row(unhex(a)=bytearray(b'ABC'))]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.unhex(_to_java_column(col))
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def sha1(col):
    +    """Returns the hex string result of SHA-1.
    +
    +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
    +    [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.sha1(_to_java_column(col))
    +    return Column(jc)
    +
    +
    +@ignore_unicode_prefix
    +@since(1.5)
    +def sha2(col, numBits):
    +    """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
    +    and SHA-512). The numBits indicates the desired bit length of the result, which must have a
    +    value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
    +
    +    >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
    +    >>> digests[0]
    +    Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
    +    >>> digests[1]
    +    Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.sha2(_to_java_column(col), numBits)
    +    return Column(jc)
    +
    +
    +@since(1.5)
    +def shiftLeft(col, numBits):
    +    """Shift the the given value numBits left.
    +
    +    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
    +    [Row(r=42)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)
    +    return Column(jc)
    +
    +
    +@since(1.5)
    +def shiftRight(col, numBits):
    +    """Shift the the given value numBits right.
    +
    +    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
    +    [Row(r=21)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
    +    return Column(jc)
    +
    +
    +@since(1.5)
    +def shiftRightUnsigned(col, numBits):
    +    """Unsigned shift the the given value numBits right.
    +
    +    >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
    +    .collect()
    +    [Row(r=9223372036854775787)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
    +    return Column(jc)
    +
    +
    +@since(1.4)
     def sparkPartitionId():
         """A column for partition ID of the Spark task.
     
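The `shiftRightUnsigned` doctest above can be sanity-checked without Spark: reinterpreting -42 as an unsigned 64-bit integer and shifting it right by one bit yields exactly the value shown in the expected output.

```
# Plain-Python check of the shiftRightUnsigned(-42, 1) doctest value.
unsigned = -42 & 0xFFFFFFFFFFFFFFFF   # -42 as an unsigned 64-bit integer
print(unsigned >> 1)                  # 9223372036854775787
```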
    @@ -274,11 +505,23 @@ def sparkPartitionId():
     
     
     @ignore_unicode_prefix
    +@since(1.5)
    +def strlen(col):
    +    """Calculates the length of a string expression.
    +
    +    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
    +    [Row(length=3)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.strlen(_to_java_column(col)))
    +
    +
    +@ignore_unicode_prefix
    +@since(1.4)
     def struct(*cols):
         """Creates a new struct column.
     
         :param cols: list of column names (string) or list of :class:`Column` expressions
    -        that are named or aliased.
     
         >>> df.select(struct('age', 'name').alias("struct")).collect()
         [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
    @@ -292,6 +535,7 @@ def struct(*cols):
         return Column(jc)
     
     
    +@since(1.4)
     def when(condition, value):
         """Evaluates a list of conditions and returns one of multiple possible result expressions.
         If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
    @@ -313,9 +557,91 @@ def when(condition, value):
         return Column(jc)
     
     
    +@since(1.5)
    +def log(arg1, arg2=None):
    +    """Returns the first argument-based logarithm of the second argument.
    +
    +    If there is only one argument, then this takes the natural logarithm of the argument.
    +
    +    >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect()
    +    ['0.30102', '0.69897']
    +
    +    >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect()
    +    ['0.69314', '1.60943']
    +    """
    +    sc = SparkContext._active_spark_context
    +    if arg2 is None:
    +        jc = sc._jvm.functions.log(_to_java_column(arg1))
    +    else:
    +        jc = sc._jvm.functions.log(arg1, _to_java_column(arg2))
    +    return Column(jc)
    +
    +
    +@since(1.5)
    +def log2(col):
    +    """Returns the base-2 logarithm of the argument.
    +
    +    >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect()
    +    [Row(log2=2.0)]
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.log2(_to_java_column(col)))
    +
    +
    +@since(1.4)
    +def lag(col, count=1, default=None):
    +    """
    +    Window function: returns the value that is `count` rows before the current row, and
    +    `default` if there are fewer than `count` rows before the current row. For example,
    +    a `count` of one returns the previous row at any given point in the window partition.
    +
    +    This is equivalent to the LAG function in SQL.
    +
    +    :param col: name of column or expression
    +    :param count: number of rows back from the current row (the offset)
    +    :param default: default value
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.lag(_to_java_column(col), count, default))
    +
    +
    +@since(1.4)
    +def lead(col, count=1, default=None):
    +    """
    +    Window function: returns the value that is `count` rows after the current row, and
    +    `default` if there are fewer than `count` rows after the current row. For example,
    +    a `count` of one returns the next row at any given point in the window partition.
    +
    +    This is equivalent to the LEAD function in SQL.
    +
    +    :param col: name of column or expression
    +    :param count: number of rows ahead of the current row (the offset)
    +    :param default: default value
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.lead(_to_java_column(col), count, default))
    +
    +
    +@since(1.4)
    +def ntile(n):
    +    """
    +    Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in
    +    a window partition. For example, if `n` is 3, the first row will get 1, the second row will
    +    get 2, the third row will get 3, and the fourth row will get 1...
    +
    +    This is equivalent to the NTILE function in SQL.
    +
    +    :param n: an integer
    +    """
    +    sc = SparkContext._active_spark_context
    +    return Column(sc._jvm.functions.ntile(int(n)))
    +
    +
     class UserDefinedFunction(object):
         """
         User defined function in Python
    +
    +    .. versionadded:: 1.3
         """
         def __init__(self, func, returnType):
             self.func = func
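The `lag`, `lead`, and `ntile` helpers added above only build the JVM expression; to produce a result they have to be bound to a window specification. A hedged sketch of typical usage, assuming the `pyspark.sql.Window` helper and `Column.over` from the same release (neither is part of this hunk) and a Hive-enabled context, which window functions require at this point:

```
# Sketch only: assumes an existing DataFrame `df` with 'name' and 'age' columns
# created from a HiveContext, plus pyspark.sql.Window / Column.over.
from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.partitionBy(df.name).orderBy(df.age)
result = df.select(
    df.name,
    df.age,
    F.lag(df.age, 1).over(w).alias('prev_age'),   # previous row's age, or null
    F.ntile(2).over(w).alias('bucket'),           # round-robin bucket 1..2
)
```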
    @@ -333,8 +659,8 @@ def _create_judf(self):
             ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
             jdt = ssql_ctx.parseDataType(self.returnType.json())
             fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
    -        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
    -                                                 includes, sc.pythonExec, broadcast_vars,
    +        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
    +                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                      sc._javaAccumulator, jdt)
             return judf
     
    @@ -349,6 +675,7 @@ def __call__(self, *cols):
             return Column(jc)
     
     
    +@since(1.3)
     def udf(f, returnType=StringType()):
         """Creates a :class:`Column` expression representing a user defined function (UDF).
     
    diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
    new file mode 100644
    index 0000000000000..04594d5a836ce
    --- /dev/null
    +++ b/python/pyspark/sql/group.py
    @@ -0,0 +1,195 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +from pyspark.rdd import ignore_unicode_prefix
    +from pyspark.sql import since
    +from pyspark.sql.column import Column, _to_seq
    +from pyspark.sql.dataframe import DataFrame
    +from pyspark.sql.types import *
    +
    +__all__ = ["GroupedData"]
    +
    +
    +def dfapi(f):
    +    def _api(self):
    +        name = f.__name__
    +        jdf = getattr(self._jdf, name)()
    +        return DataFrame(jdf, self.sql_ctx)
    +    _api.__name__ = f.__name__
    +    _api.__doc__ = f.__doc__
    +    return _api
    +
    +
    +def df_varargs_api(f):
    +    def _api(self, *args):
    +        name = f.__name__
    +        jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
    +        return DataFrame(jdf, self.sql_ctx)
    +    _api.__name__ = f.__name__
    +    _api.__doc__ = f.__doc__
    +    return _api
    +
    +
    +class GroupedData(object):
    +    """
    +    A set of methods for aggregations on a :class:`DataFrame`,
    +    created by :func:`DataFrame.groupBy`.
    +
    +    .. note:: Experimental
    +
    +    .. versionadded:: 1.3
    +    """
    +
    +    def __init__(self, jdf, sql_ctx):
    +        self._jdf = jdf
    +        self.sql_ctx = sql_ctx
    +
    +    @ignore_unicode_prefix
    +    @since(1.3)
    +    def agg(self, *exprs):
    +        """Compute aggregates and returns the result as a :class:`DataFrame`.
    +
    +        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
    +
    +        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
    +        is the column to perform aggregation on, and the value is the aggregate function.
    +
    +        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
    +
    +        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
    +            or a list of :class:`Column`.
    +
    +        >>> gdf = df.groupBy(df.name)
    +        >>> gdf.agg({"*": "count"}).collect()
    +        [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]
    +
    +        >>> from pyspark.sql import functions as F
    +        >>> gdf.agg(F.min(df.age)).collect()
    +        [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
    +        """
    +        assert exprs, "exprs should not be empty"
    +        if len(exprs) == 1 and isinstance(exprs[0], dict):
    +            jdf = self._jdf.agg(exprs[0])
    +        else:
    +            # Columns
    +            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
    +            jdf = self._jdf.agg(exprs[0]._jc,
    +                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
    +        return DataFrame(jdf, self.sql_ctx)
    +
    +    @dfapi
    +    @since(1.3)
    +    def count(self):
    +        """Counts the number of records for each group.
    +
    +        >>> df.groupBy(df.age).count().collect()
    +        [Row(age=2, count=1), Row(age=5, count=1)]
    +        """
    +
    +    @df_varargs_api
    +    @since(1.3)
    +    def mean(self, *cols):
    +        """Computes average values for each numeric columns for each group.
    +
    +        :func:`mean` is an alias for :func:`avg`.
    +
    +        :param cols: list of column names (string). Non-numeric columns are ignored.
    +
    +        >>> df.groupBy().mean('age').collect()
    +        [Row(avg(age)=3.5)]
    +        >>> df3.groupBy().mean('age', 'height').collect()
    +        [Row(avg(age)=3.5, avg(height)=82.5)]
    +        """
    +
    +    @df_varargs_api
    +    @since(1.3)
    +    def avg(self, *cols):
    +        """Computes average values for each numeric columns for each group.
    +
    +        :func:`mean` is an alias for :func:`avg`.
    +
    +        :param cols: list of column names (string). Non-numeric columns are ignored.
    +
    +        >>> df.groupBy().avg('age').collect()
    +        [Row(avg(age)=3.5)]
    +        >>> df3.groupBy().avg('age', 'height').collect()
    +        [Row(avg(age)=3.5, avg(height)=82.5)]
    +        """
    +
    +    @df_varargs_api
    +    @since(1.3)
    +    def max(self, *cols):
    +        """Computes the max value for each numeric columns for each group.
    +
    +        >>> df.groupBy().max('age').collect()
    +        [Row(max(age)=5)]
    +        >>> df3.groupBy().max('age', 'height').collect()
    +        [Row(max(age)=5, max(height)=85)]
    +        """
    +
    +    @df_varargs_api
    +    @since(1.3)
    +    def min(self, *cols):
    +        """Computes the min value for each numeric column for each group.
    +
    +        :param cols: list of column names (string). Non-numeric columns are ignored.
    +
    +        >>> df.groupBy().min('age').collect()
    +        [Row(min(age)=2)]
    +        >>> df3.groupBy().min('age', 'height').collect()
    +        [Row(min(age)=2, min(height)=80)]
    +        """
    +
    +    @df_varargs_api
    +    @since(1.3)
    +    def sum(self, *cols):
    +        """Compute the sum for each numeric columns for each group.
    +
    +        :param cols: list of column names (string). Non-numeric columns are ignored.
    +
    +        >>> df.groupBy().sum('age').collect()
    +        [Row(sum(age)=7)]
    +        >>> df3.groupBy().sum('age', 'height').collect()
    +        [Row(sum(age)=7, sum(height)=165)]
    +        """
    +
    +
    +def _test():
    +    import doctest
    +    from pyspark.context import SparkContext
    +    from pyspark.sql import Row, SQLContext
    +    import pyspark.sql.group
    +    globs = pyspark.sql.group.__dict__.copy()
    +    sc = SparkContext('local[4]', 'PythonTest')
    +    globs['sc'] = sc
    +    globs['sqlContext'] = SQLContext(sc)
    +    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
    +        .toDF(StructType([StructField('age', IntegerType()),
    +                          StructField('name', StringType())]))
    +    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
    +                                   Row(name='Bob', age=5, height=85)]).toDF()
    +
    +    (failure_count, test_count) = doctest.testmod(
    +        pyspark.sql.group, globs=globs,
    +        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    +    globs['sc'].stop()
    +    if failure_count:
    +        exit(-1)
    +
    +
    +if __name__ == "__main__":
    +    _test()
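Outside of the doctest harness driven by `_test()` above, the two calling conventions documented in `GroupedData.agg` look like this; a minimal sketch assuming a local context:

```
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F

sc = SparkContext('local[2]', 'groupby-example')
sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name'])

gdf = df.groupBy(df.name)
print(gdf.agg({'*': 'count'}).collect())   # dict form: column name -> aggregate name
print(gdf.agg(F.min(df.age)).collect())    # Column-expression form

sc.stop()
```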
    diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
    new file mode 100644
    index 0000000000000..882a03090ec13
    --- /dev/null
    +++ b/python/pyspark/sql/readwriter.py
    @@ -0,0 +1,434 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +from py4j.java_gateway import JavaClass
    +
    +from pyspark.sql import since
    +from pyspark.sql.column import _to_seq
    +from pyspark.sql.types import *
    +
    +__all__ = ["DataFrameReader", "DataFrameWriter"]
    +
    +
    +class DataFrameReader(object):
    +    """
    +    Interface used to load a :class:`DataFrame` from external storage systems
    +    (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read`
    +    to access this.
    +
    +    .. note:: Experimental
    +
    +    .. versionadded:: 1.4
    +    """
    +
    +    def __init__(self, sqlContext):
    +        self._jreader = sqlContext._ssql_ctx.read()
    +        self._sqlContext = sqlContext
    +
    +    def _df(self, jdf):
    +        from pyspark.sql.dataframe import DataFrame
    +        return DataFrame(jdf, self._sqlContext)
    +
    +    @since(1.4)
    +    def format(self, source):
    +        """Specifies the input data source format.
    +
    +        :param source: string, name of the data source, e.g. 'json', 'parquet'.
    +
    +        >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
    +        >>> df.dtypes
    +        [('age', 'bigint'), ('name', 'string')]
    +
    +        """
    +        self._jreader = self._jreader.format(source)
    +        return self
    +
    +    @since(1.4)
    +    def schema(self, schema):
    +        """Specifies the input schema.
    +
    +        Some data sources (e.g. JSON) can infer the input schema automatically from data.
    +        By specifying the schema here, the underlying data source can skip the schema
    +        inference step, and thus speed up data loading.
    +
    +        :param schema: a StructType object
    +        """
    +        if not isinstance(schema, StructType):
    +            raise TypeError("schema should be StructType")
    +        jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
    +        self._jreader = self._jreader.schema(jschema)
    +        return self
    +
    +    @since(1.5)
    +    def option(self, key, value):
    +        """Adds an input option for the underlying data source.
    +        """
    +        self._jreader = self._jreader.option(key, value)
    +        return self
    +
    +    @since(1.4)
    +    def options(self, **options):
    +        """Adds input options for the underlying data source.
    +        """
    +        for k in options:
    +            self._jreader = self._jreader.option(k, options[k])
    +        return self
    +
    +    @since(1.4)
    +    def load(self, path=None, format=None, schema=None, **options):
    +        """Loads data from a data source and returns it as a :class`DataFrame`.
    +
    +        :param path: optional string for file-system backed data sources.
    +        :param format: optional string for format of the data source. Defaults to 'parquet'.
    +        :param schema: optional :class:`StructType` for the input schema.
    +        :param options: all other string options
    +
    +        >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned')
    +        >>> df.dtypes
    +        [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
    +        """
    +        if format is not None:
    +            self.format(format)
    +        if schema is not None:
    +            self.schema(schema)
    +        self.options(**options)
    +        if path is not None:
    +            return self._df(self._jreader.load(path))
    +        else:
    +            return self._df(self._jreader.load())
    +
    +    @since(1.4)
    +    def json(self, path, schema=None):
    +        """
    +        Loads a JSON file (one object per line) and returns the result as
    +        a :class:`DataFrame`.
    +
    +        If the ``schema`` parameter is not specified, this function goes
    +        through the input once to determine the input schema.
    +
    +        :param path: string, path to the JSON dataset.
    +        :param schema: an optional :class:`StructType` for the input schema.
    +
    +        >>> df = sqlContext.read.json('python/test_support/sql/people.json')
    +        >>> df.dtypes
    +        [('age', 'bigint'), ('name', 'string')]
    +
    +        """
    +        if schema is not None:
    +            self.schema(schema)
    +        return self._df(self._jreader.json(path))
    +
    +    @since(1.4)
    +    def table(self, tableName):
    +        """Returns the specified table as a :class:`DataFrame`.
    +
    +        :param tableName: string, name of the table.
    +
    +        >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
    +        >>> df.registerTempTable('tmpTable')
    +        >>> sqlContext.read.table('tmpTable').dtypes
    +        [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
    +        """
    +        return self._df(self._jreader.table(tableName))
    +
    +    @since(1.4)
    +    def parquet(self, *path):
    +        """Loads a Parquet file, returning the result as a :class:`DataFrame`.
    +
    +        >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
    +        >>> df.dtypes
    +        [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
    +        """
    +        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))
    +
    +    @since(1.4)
    +    def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
    +             predicates=None, properties={}):
    +        """
    +        Construct a :class:`DataFrame` representing the database table named `table`,
    +        accessible via the JDBC URL `url` and connection `properties`.
    +
    +        If `column` is given, the table is partitioned on that column and retrieved
    +        in parallel based on the other parameters passed to this function.
    +
    +        The `predicates` parameter gives a list of expressions suitable for inclusion
    +        in WHERE clauses; each one defines one partition of the :class:`DataFrame`.
    +
    +        .. note:: Don't create too many partitions in parallel on a large cluster;
    +            otherwise Spark might crash your external database systems.
    +
    +        :param url: a JDBC URL
    +        :param table: name of table
    +        :param column: the column used to partition
    +        :param lowerBound: the lower bound of the partition column
    +        :param upperBound: the upper bound of the partition column
    +        :param numPartitions: the number of partitions
    +        :param predicates: a list of expressions
    +        :param properties: JDBC database connection arguments, a dict of arbitrary string
    +                           tag/value pairs. Normally at least a "user" and "password"
    +                           property should be included.
    +        :return: a DataFrame
    +        """
    +        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
    +        for k in properties:
    +            jprop.setProperty(k, properties[k])
    +        if column is not None:
    +            if numPartitions is None:
    +                numPartitions = self._sqlContext._sc.defaultParallelism
    +            return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound),
    +                                               int(numPartitions), jprop))
    +        if predicates is not None:
    +            arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates)
    +            return self._df(self._jreader.jdbc(url, table, arr, jprop))
    +        return self._df(self._jreader.jdbc(url, table, jprop))
    +
    +
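Each builder method on `DataFrameReader` returns `self`, so the reader is intended to be used as a fluent chain ending in `load()` or one of the format shortcuts. A hedged sketch against the test fixtures referenced in the doctests above; it assumes an existing `sqlContext` and paths resolved relative to the Spark source tree, and the option key is a placeholder, since keys are passed through to the data source unchanged:

```
# Fluent reader chains; `sqlContext` is assumed to exist as in the doctests.
people = sqlContext.read.format('json').load('python/test_support/sql/people.json')

partitioned = (sqlContext.read
               .format('parquet')
               .option('someKey', 'someValue')   # placeholder option, passed through as-is
               .load('python/test_support/sql/parquet_partitioned'))

people2 = sqlContext.read.json('python/test_support/sql/people.json')
```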
    +class DataFrameWriter(object):
    +    """
    +    Interface used to write a :class:`DataFrame` to external storage systems
    +    (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write`
    +    to access this.
    +
    +    .. note:: Experimental
    +
    +    .. versionadded:: 1.4
    +    """
    +    def __init__(self, df):
    +        self._df = df
    +        self._sqlContext = df.sql_ctx
    +        self._jwrite = df._jdf.write()
    +
    +    @since(1.4)
    +    def mode(self, saveMode):
    +        """Specifies the behavior when data or table already exists.
    +
    +        Options include:
    +
    +        * `append`: Append contents of this :class:`DataFrame` to existing data.
    +        * `overwrite`: Overwrite existing data.
    +        * `error`: Throw an exception if data already exists.
    +        * `ignore`: Silently ignore this operation if data already exists.
    +
    +        >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
    +        """
    +        # At the JVM side, the default value of mode is already set to "error".
    +        # So, if the given saveMode is None, we will not call JVM-side's mode method.
    +        if saveMode is not None:
    +            self._jwrite = self._jwrite.mode(saveMode)
    +        return self
    +
    +    @since(1.4)
    +    def format(self, source):
    +        """Specifies the underlying output data source.
    +
    +        :param source: string, name of the data source, e.g. 'json', 'parquet'.
    +
    +        >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data'))
    +        """
    +        self._jwrite = self._jwrite.format(source)
    +        return self
    +
    +    @since(1.5)
    +    def option(self, key, value):
    +        """Adds an output option for the underlying data source.
    +        """
    +        self._jwrite = self._jwrite.option(key, value)
    +        return self
    +
    +    @since(1.4)
    +    def options(self, **options):
    +        """Adds output options for the underlying data source.
    +        """
    +        for k in options:
    +            self._jwrite = self._jwrite.option(k, options[k])
    +        return self
    +
    +    @since(1.4)
    +    def partitionBy(self, *cols):
    +        """Partitions the output by the given columns on the file system.
    +
    +        If specified, the output is laid out on the file system similar
    +        to Hive's partitioning scheme.
    +
    +        :param cols: name of columns
    +
    +        >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
    +        """
    +        if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
    +            cols = cols[0]
    +        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
    +        return self
    +
    +    @since(1.4)
    +    def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
    +        """Saves the contents of the :class:`DataFrame` to a data source.
    +
    +        The data source is specified by the ``format`` and a set of ``options``.
    +        If ``format`` is not specified, the default data source configured by
    +        ``spark.sql.sources.default`` will be used.
    +
    +        :param path: the path in a Hadoop supported file system
    +        :param format: the format used to save
    +        :param mode: specifies the behavior of the save operation when data already exists.
    +
    +            * ``append``: Append contents of this :class:`DataFrame` to existing data.
    +            * ``overwrite``: Overwrite existing data.
    +            * ``ignore``: Silently ignore this operation if data already exists.
    +            * ``error`` (default case): Throw an exception if data already exists.
    +        :param partitionBy: names of partitioning columns
    +        :param options: all other string options
    +
    +        >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
    +        """
    +        self.mode(mode).options(**options)
    +        if partitionBy is not None:
    +            self.partitionBy(partitionBy)
    +        if format is not None:
    +            self.format(format)
    +        if path is None:
    +            self._jwrite.save()
    +        else:
    +            self._jwrite.save(path)
    +
    +    @since(1.4)
    +    def insertInto(self, tableName, overwrite=False):
    +        """Inserts the content of the :class:`DataFrame` to the specified table.
    +
    +        It requires that the schema of the :class:`DataFrame` is the same as the
    +        schema of the table.
    +
    +        If ``overwrite`` is True, any existing data in the table is overwritten.
    +        """
    +        self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
    +
    +    @since(1.4)
    +    def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options):
    +        """Saves the content of the :class:`DataFrame` as the specified table.
    +
    +        If the table already exists, the behavior of this function depends on the
    +        save mode, specified by the `mode` function (defaults to throwing an exception).
    +        When `mode` is `overwrite`, the schema of the :class:`DataFrame` does not need to
    +        be the same as that of the existing table.
    +
    +        * ``append``: Append contents of this :class:`DataFrame` to existing data.
    +        * ``overwrite``: Overwrite existing data.
    +        * ``error``: Throw an exception if data already exists.
    +        * ``ignore``: Silently ignore this operation if data already exists.
    +
    +        :param name: the table name
    +        :param format: the format used to save
    +        :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
    +        :param partitionBy: names of partitioning columns
    +        :param options: all other string options
    +        """
    +        self.mode(mode).options(**options)
    +        if partitionBy is not None:
    +            self.partitionBy(partitionBy)
    +        if format is not None:
    +            self.format(format)
    +        self._jwrite.saveAsTable(name)
    +
    +    @since(1.4)
    +    def json(self, path, mode=None):
    +        """Saves the content of the :class:`DataFrame` in JSON format at the specified path.
    +
    +        :param path: the path in any Hadoop supported file system
    +        :param mode: specifies the behavior of the save operation when data already exists.
    +
    +            * ``append``: Append contents of this :class:`DataFrame` to existing data.
    +            * ``overwrite``: Overwrite existing data.
    +            * ``ignore``: Silently ignore this operation if data already exists.
    +            * ``error`` (default case): Throw an exception if data already exists.
    +
    +        >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
    +        """
    +        self.mode(mode)._jwrite.json(path)
    +
    +    @since(1.4)
    +    def parquet(self, path, mode=None, partitionBy=None):
    +        """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
    +
    +        :param path: the path in any Hadoop supported file system
    +        :param mode: specifies the behavior of the save operation when data already exists.
    +
    +            * ``append``: Append contents of this :class:`DataFrame` to existing data.
    +            * ``overwrite``: Overwrite existing data.
    +            * ``ignore``: Silently ignore this operation if data already exists.
    +            * ``error`` (default case): Throw an exception if data already exists.
    +        :param partitionBy: names of partitioning columns
    +
    +        >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
    +        """
    +        self.mode(mode)
    +        if partitionBy is not None:
    +            self.partitionBy(partitionBy)
    +        self._jwrite.parquet(path)
    +
    +    @since(1.4)
    +    def jdbc(self, url, table, mode=None, properties={}):
    +        """Saves the content of the :class:`DataFrame` to an external database table via JDBC.
    +
    +        .. note:: Don't create too many partitions in parallel on a large cluster;\
    +        otherwise Spark might crash your external database systems.
    +
    +        :param url: a JDBC URL of the form ``jdbc:subprotocol:subname``
    +        :param table: Name of the table in the external database.
    +        :param mode: specifies the behavior of the save operation when data already exists.
    +
    +            * ``append``: Append contents of this :class:`DataFrame` to existing data.
    +            * ``overwrite``: Overwrite existing data.
    +            * ``ignore``: Silently ignore this operation if data already exists.
    +            * ``error`` (default case): Throw an exception if data already exists.
    +        :param properties: JDBC database connection arguments, a dictionary of
    +                           arbitrary string key/value pairs. Normally at least a
    +                           "user" and "password" property should be included.
    +        """
    +        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
    +        for k in properties:
    +            jprop.setProperty(k, properties[k])
    +        self.mode(mode)._jwrite.jdbc(url, table, jprop)
    +
    +
    +def _test():
    +    import doctest
    +    import os
    +    import tempfile
    +    from pyspark.context import SparkContext
    +    from pyspark.sql import Row, SQLContext
    +    import pyspark.sql.readwriter
    +
    +    os.chdir(os.environ["SPARK_HOME"])
    +
    +    globs = pyspark.sql.readwriter.__dict__.copy()
    +    sc = SparkContext('local[4]', 'PythonTest')
    +
    +    globs['tempfile'] = tempfile
    +    globs['os'] = os
    +    globs['sc'] = sc
    +    globs['sqlContext'] = SQLContext(sc)
    +    globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
    +
    +    (failure_count, test_count) = doctest.testmod(
    +        pyspark.sql.readwriter, globs=globs,
    +        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    +    globs['sc'].stop()
    +    if failure_count:
    +        exit(-1)
    +
    +
    +if __name__ == "__main__":
    +    _test()
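    Taken together, the new DataFrameWriter methods form a builder chain: format/mode/
    option/partitionBy configure the write, while save/json/parquet/saveAsTable/jdbc are
    terminal. A minimal usage sketch, assuming an existing DataFrame `df`, a temporary
    output directory, and a reachable JDBC endpoint (the URL, table names, and the
    'someKey' option below are placeholders, not values from this patch):

        import os
        import tempfile

        out = os.path.join(tempfile.mkdtemp(), 'data')

        # explicit builder chain ending in save()
        df.write.format('json').mode('overwrite').option('someKey', 'someValue').save(out)

        # shorthand writers bundle the format with the terminal call
        df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))

        # saveAsTable and jdbc are terminal as well
        df.write.saveAsTable('saved_json_table', format='json', mode='overwrite')
        df.write.jdbc('jdbc:postgresql://localhost/test', 'people', mode='overwrite',
                      properties={'user': 'su', 'password': ''})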
    diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
    index 1922d03af61da..241eac45cfe36 100644
    --- a/python/pyspark/sql/tests.py
    +++ b/python/pyspark/sql/tests.py
    @@ -1,3 +1,4 @@
    +# -*- encoding: utf-8 -*-
     #
     # Licensed to the Apache Software Foundation (ASF) under one or more
     # contributor license agreements.  See the NOTICE file distributed with
    @@ -26,6 +27,7 @@
     import tempfile
     import pickle
     import functools
    +import time
     import datetime
     
     import py4j
    @@ -44,6 +46,22 @@
     from pyspark.sql.types import UserDefinedType, _infer_type
     from pyspark.tests import ReusedPySparkTestCase
     from pyspark.sql.functions import UserDefinedFunction
    +from pyspark.sql.window import Window
    +from pyspark.sql.utils import AnalysisException
    +
    +
    +class UTC(datetime.tzinfo):
    +    """UTC"""
    +    ZERO = datetime.timedelta(0)
    +
    +    def utcoffset(self, dt):
    +        return self.ZERO
    +
    +    def tzname(self, dt):
    +        return "UTC"
    +
    +    def dst(self, dt):
    +        return self.ZERO
     
     
     class ExamplePointUDT(UserDefinedType):
    @@ -99,6 +117,15 @@ def test_data_type_eq(self):
             lt2 = pickle.loads(pickle.dumps(LongType()))
             self.assertEquals(lt, lt2)
     
    +    # regression test for SPARK-7978
    +    def test_decimal_type(self):
    +        t1 = DecimalType()
    +        t2 = DecimalType(10, 2)
    +        self.assertTrue(t2 is not t1)
    +        self.assertNotEqual(t1, t2)
    +        t3 = DecimalType(8)
    +        self.assertNotEqual(t2, t3)
    +
     
     class SQLTests(ReusedPySparkTestCase):
     
    @@ -117,6 +144,47 @@ def tearDownClass(cls):
             ReusedPySparkTestCase.tearDownClass()
             shutil.rmtree(cls.tempdir.name, ignore_errors=True)
     
    +    def test_range(self):
    +        self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
    +        self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
    +        self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
    +        self.assertEqual(self.sqlCtx.range(-2).count(), 0)
    +        self.assertEqual(self.sqlCtx.range(3).count(), 3)
    +
    +    def test_duplicated_column_names(self):
    +        df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
    +        row = df.select('*').first()
    +        self.assertEqual(1, row[0])
    +        self.assertEqual(2, row[1])
    +        self.assertEqual("Row(c=1, c=2)", str(row))
    +        # Cannot access columns
    +        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
    +        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
    +        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
    +
    +    def test_explode(self):
    +        from pyspark.sql.functions import explode
    +        d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
    +        rdd = self.sc.parallelize(d)
    +        data = self.sqlCtx.createDataFrame(rdd)
    +
    +        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
    +        self.assertEqual(result[0][0], 1)
    +        self.assertEqual(result[1][0], 2)
    +        self.assertEqual(result[2][0], 3)
    +
    +        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
    +        self.assertEqual(result[0][0], "a")
    +        self.assertEqual(result[0][1], "b")
    +
    +    def test_and_in_expression(self):
    +        self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
    +        self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
    +        self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
    +        self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
    +        self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
    +        self.assertRaises(ValueError, lambda: not self.df.key == 1)
    +
         def test_udf_with_callable(self):
             d = [Row(number=i, squared=i**2) for i in range(10)]
             rdd = self.sc.parallelize(d)
    @@ -344,6 +412,14 @@ def test_apply_schema_with_udt(self):
             point = df.head().point
             self.assertEquals(point, ExamplePoint(1.0, 2.0))
     
    +    def test_udf_with_udt(self):
    +        from pyspark.sql.tests import ExamplePoint
    +        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
    +        df = self.sc.parallelize([row]).toDF()
    +        self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
    +        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
    +        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
    +
         def test_parquet_with_udt(self):
             from pyspark.sql.tests import ExamplePoint
             row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
    @@ -361,7 +437,7 @@ def test_column_operators(self):
             self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
             rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
             self.assertTrue(all(isinstance(c, Column) for c in rcc))
    -        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
    +        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
             self.assertTrue(all(isinstance(c, Column) for c in cb))
             cbool = (ci & ci), (ci | ci), (~ci)
             self.assertTrue(all(isinstance(c, Column) for c in cbool))
    @@ -461,33 +537,95 @@ def test_between_function(self):
             self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
                              df.filter(df.a.between(df.b, df.c)).collect())
     
    +    def test_struct_type(self):
    +        from pyspark.sql.types import StructType, StringType, StructField
    +        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
    +        struct2 = StructType([StructField("f1", StringType(), True),
    +                              StructField("f2", StringType(), True, None)])
    +        self.assertEqual(struct1, struct2)
    +
    +        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
    +        struct2 = StructType([StructField("f1", StringType(), True)])
    +        self.assertNotEqual(struct1, struct2)
    +
    +        struct1 = (StructType().add(StructField("f1", StringType(), True))
    +                   .add(StructField("f2", StringType(), True, None)))
    +        struct2 = StructType([StructField("f1", StringType(), True),
    +                              StructField("f2", StringType(), True, None)])
    +        self.assertEqual(struct1, struct2)
    +
    +        struct1 = (StructType().add(StructField("f1", StringType(), True))
    +                   .add(StructField("f2", StringType(), True, None)))
    +        struct2 = StructType([StructField("f1", StringType(), True)])
    +        self.assertNotEqual(struct1, struct2)
    +
    +        # Catch exception raised during improper construction
    +        try:
    +            struct1 = StructType().add("name")
    +            self.assertEqual(1, 0)
    +        except ValueError:
    +            self.assertEqual(1, 1)
    +
         def test_save_and_load(self):
             df = self.df
             tmpPath = tempfile.mkdtemp()
             shutil.rmtree(tmpPath)
    -        df.save(tmpPath, "org.apache.spark.sql.json", "error")
    -        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
    -        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
    +        df.write.json(tmpPath)
    +        actual = self.sqlCtx.read.json(tmpPath)
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
     
             schema = StructType([StructField("value", StringType(), True)])
    -        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
    -        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
    +        actual = self.sqlCtx.read.json(tmpPath, schema)
    +        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
     
    -        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
    -        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
    -        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
    +        df.write.json(tmpPath, "overwrite")
    +        actual = self.sqlCtx.read.json(tmpPath)
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
     
    -        df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
    -                noUse="this options will not be used in save.")
    -        actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
    -                                  noUse="this options will not be used in load.")
    -        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
    +        df.write.save(format="json", mode="overwrite", path=tmpPath,
    +                      noUse="this options will not be used in save.")
    +        actual = self.sqlCtx.read.load(format="json", path=tmpPath,
    +                                       noUse="this options will not be used in load.")
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
     
             defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                         "org.apache.spark.sql.parquet")
             self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
             actual = self.sqlCtx.load(path=tmpPath)
    -        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
    +        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
    +
    +        shutil.rmtree(tmpPath)
    +
    +    def test_save_and_load_builder(self):
    +        df = self.df
    +        tmpPath = tempfile.mkdtemp()
    +        shutil.rmtree(tmpPath)
    +        df.write.json(tmpPath)
    +        actual = self.sqlCtx.read.json(tmpPath)
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
    +
    +        schema = StructType([StructField("value", StringType(), True)])
    +        actual = self.sqlCtx.read.json(tmpPath, schema)
    +        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
    +
    +        df.write.mode("overwrite").json(tmpPath)
    +        actual = self.sqlCtx.read.json(tmpPath)
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
    +
    +        df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
    +                .option("noUse", "this option will not be used in save.")\
    +                .format("json").save(path=tmpPath)
    +        actual =\
    +            self.sqlCtx.read.format("json")\
    +                            .load(path=tmpPath, noUse="this options will not be used in load.")
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
    +
    +        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
    +                                                    "org.apache.spark.sql.parquet")
    +        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
    +        actual = self.sqlCtx.load(path=tmpPath)
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
             self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
     
             shutil.rmtree(tmpPath)
    @@ -510,6 +648,14 @@ def test_access_column(self):
             self.assertRaises(IndexError, lambda: df["bad_key"])
             self.assertRaises(TypeError, lambda: df[{}])
     
    +    def test_column_name_with_non_ascii(self):
    +        df = self.sqlCtx.createDataFrame([(1,)], ["数量"])
    +        self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
    +        self.assertEqual("DataFrame[数量: bigint]", str(df))
    +        self.assertEqual([("数量", 'bigint')], df.dtypes)
    +        self.assertEqual(1, df.select("数量").first()[0])
    +        self.assertEqual(1, df.select(df["数量"]).first()[0])
    +
         def test_access_nested_types(self):
             df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
             self.assertEqual(1, df.select(df.l[0]).first()[0])
    @@ -556,6 +702,35 @@ def test_filter_with_datetime(self):
             self.assertEqual(0, df.filter(df.date > date).count())
             self.assertEqual(0, df.filter(df.time > time).count())
     
    +    def test_time_with_timezone(self):
    +        day = datetime.date.today()
    +        now = datetime.datetime.now()
    +        ts = time.mktime(now.timetuple())
    +        # class in __main__ is not serializable
    +        from pyspark.sql.tests import UTC
    +        utc = UTC()
    +        utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
    +        # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
    +        utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
    +        df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
    +        day1, now1, utcnow1 = df.first()
    +        self.assertEqual(day1, day)
    +        self.assertEqual(now, now1)
    +        self.assertEqual(now, utcnow1)
    +
    +    def test_decimal(self):
    +        from decimal import Decimal
    +        schema = StructType([StructField("decimal", DecimalType(10, 5))])
    +        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
    +        row = df.select(df.decimal + 1).first()
    +        self.assertEqual(row[0], Decimal("4.14159"))
    +        tmpPath = tempfile.mkdtemp()
    +        shutil.rmtree(tmpPath)
    +        df.write.parquet(tmpPath)
    +        df2 = self.sqlCtx.read.parquet(tmpPath)
    +        row = df2.first()
    +        self.assertEqual(row[0], Decimal("3.14159"))
    +
         def test_dropna(self):
             schema = StructType([
                 StructField("name", StringType(), True),
    @@ -713,6 +888,12 @@ def test_replace(self):
             self.assertEqual(row.age, 10)
             self.assertEqual(row.height, None)
     
    +    def test_capture_analysis_exception(self):
    +        self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
    +        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
    +        # RuntimeException should not be captured
    +        self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc"))
    +
     
     class HiveContextSQLTests(ReusedPySparkTestCase):
     
    @@ -723,11 +904,11 @@ def setUpClass(cls):
             try:
                 cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
             except py4j.protocol.Py4JError:
    -            cls.sqlCtx = None
    -            return
    +            cls.tearDownClass()
    +            raise unittest.SkipTest("Hive is not available")
             except TypeError:
    -            cls.sqlCtx = None
    -            return
    +            cls.tearDownClass()
    +            raise unittest.SkipTest("Hive is not available")
             os.unlink(cls.tempdir.name)
             _scala_HiveContext =\
                 cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
    @@ -741,57 +922,68 @@ def tearDownClass(cls):
             shutil.rmtree(cls.tempdir.name, ignore_errors=True)
     
         def test_save_and_load_table(self):
    -        if self.sqlCtx is None:
    -            return  # no hive available, skipped
    -
             df = self.df
             tmpPath = tempfile.mkdtemp()
             shutil.rmtree(tmpPath)
    -        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
    -        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
    -                                                 "org.apache.spark.sql.json")
    -        self.assertTrue(
    -            sorted(df.collect()) ==
    -            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    -        self.assertTrue(
    -            sorted(df.collect()) ==
    -            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    -        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
    +        df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
    +        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
    +        self.assertEqual(sorted(df.collect()),
    +                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    +        self.assertEqual(sorted(df.collect()),
    +                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
             self.sqlCtx.sql("DROP TABLE externalJsonTable")
     
    -        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
    +        df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
             schema = StructType([StructField("value", StringType(), True)])
    -        actual = self.sqlCtx.createExternalTable("externalJsonTable",
    -                                                 source="org.apache.spark.sql.json",
    +        actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
                                                      schema=schema, path=tmpPath,
                                                      noUse="this options will not be used")
    -        self.assertTrue(
    -            sorted(df.collect()) ==
    -            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    -        self.assertTrue(
    -            sorted(df.select("value").collect()) ==
    -            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    -        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
    +        self.assertEqual(sorted(df.collect()),
    +                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    +        self.assertEqual(sorted(df.select("value").collect()),
    +                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    +        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
             self.sqlCtx.sql("DROP TABLE savedJsonTable")
             self.sqlCtx.sql("DROP TABLE externalJsonTable")
     
             defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
                                                         "org.apache.spark.sql.parquet")
             self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
    -        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
    +        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
             actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
    -        self.assertTrue(
    -            sorted(df.collect()) ==
    -            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    -        self.assertTrue(
    -            sorted(df.collect()) ==
    -            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    -        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
    +        self.assertEqual(sorted(df.collect()),
    +                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
    +        self.assertEqual(sorted(df.collect()),
    +                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
    +        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
             self.sqlCtx.sql("DROP TABLE savedJsonTable")
             self.sqlCtx.sql("DROP TABLE externalJsonTable")
             self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
     
             shutil.rmtree(tmpPath)
     
    +    def test_window_functions(self):
    +        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
    +        w = Window.partitionBy("value").orderBy("key")
    +        from pyspark.sql import functions as F
    +        sel = df.select(df.value, df.key,
    +                        F.max("key").over(w.rowsBetween(0, 1)),
    +                        F.min("key").over(w.rowsBetween(0, 1)),
    +                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
    +                        F.rowNumber().over(w),
    +                        F.rank().over(w),
    +                        F.denseRank().over(w),
    +                        F.ntile(2).over(w))
    +        rs = sorted(sel.collect())
    +        expected = [
    +            ("1", 1, 1, 1, 1, 1, 1, 1, 1),
    +            ("2", 1, 1, 1, 3, 1, 1, 1, 1),
    +            ("2", 1, 2, 1, 3, 2, 1, 1, 1),
    +            ("2", 2, 2, 2, 3, 3, 3, 2, 2)
    +        ]
    +        for r, ex in zip(rs, expected):
    +            self.assertEqual(tuple(r), ex[:len(r)])
    +
     if __name__ == "__main__":
         unittest.main()
    diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py
    similarity index 75%
    rename from python/pyspark/sql/_types.py
    rename to python/pyspark/sql/types.py
    index b96851a174d49..f75791fad1612 100644
    --- a/python/pyspark/sql/_types.py
    +++ b/python/pyspark/sql/types.py
    @@ -19,13 +19,10 @@
     import decimal
     import time
     import datetime
    -import keyword
    -import warnings
    +import calendar
     import json
     import re
    -import weakref
     from array import array
    -from operator import itemgetter
     
     if sys.version >= "3":
         long = int
    @@ -70,59 +67,132 @@ def json(self):
                               separators=(',', ':'),
                               sort_keys=True)
     
    +    def needConversion(self):
    +        """
    +        Does this type need conversion between Python object and internal SQL object?
    +
    +        This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
    +        """
    +        return False
    +
    +    def toInternal(self, obj):
    +        """
    +        Converts a Python object into an internal SQL object.
    +        """
    +        return obj
    +
    +    def fromInternal(self, obj):
    +        """
    +        Converts an internal SQL object into a native Python object.
    +        """
    +        return obj
    +
     
     # This singleton pattern does not work with pickle, you will get
     # another object after pickle and unpickle
    -class PrimitiveTypeSingleton(type):
    -    """Metaclass for PrimitiveType"""
    +class DataTypeSingleton(type):
    +    """Metaclass for DataType"""
     
         _instances = {}
     
         def __call__(cls):
             if cls not in cls._instances:
    -            cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__()
    +            cls._instances[cls] = super(DataTypeSingleton, cls).__call__()
             return cls._instances[cls]
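    A quick sketch of what the singleton metaclass buys (on Python 2, where the
    `__metaclass__` attribute used below takes effect, repeated construction of a
    parameter-less type returns the same cached instance):

        from pyspark.sql.types import LongType, StringType

        assert StringType() is StringType()    # same cached instance
        assert LongType() is LongType()
        assert StringType() is not LongType()  # cache is keyed by class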
     
     
    -class PrimitiveType(DataType):
    -    """Spark SQL PrimitiveType"""
    +class NullType(DataType):
    +    """Null type.
     
    -    __metaclass__ = PrimitiveTypeSingleton
    +    The data type representing None, used for the types that cannot be inferred.
    +    """
     
    +    __metaclass__ = DataTypeSingleton
     
    -class NullType(PrimitiveType):
    -    """Null type.
     
    -    The data type representing None, used for the types that cannot be inferred.
    +class AtomicType(DataType):
    +    """An internal type used to represent everything that is not
    +    null, UDTs, arrays, structs, and maps."""
    +
    +
    +class NumericType(AtomicType):
    +    """Numeric data types.
    +    """
    +
    +
    +class IntegralType(NumericType):
    +    """Integral data types.
         """
     
    +    __metaclass__ = DataTypeSingleton
     
    -class StringType(PrimitiveType):
    +
    +class FractionalType(NumericType):
    +    """Fractional data types.
    +    """
    +
    +
    +class StringType(AtomicType):
         """String data type.
         """
     
    +    __metaclass__ = DataTypeSingleton
    +
     
    -class BinaryType(PrimitiveType):
    +class BinaryType(AtomicType):
         """Binary (byte array) data type.
         """
     
    +    __metaclass__ = DataTypeSingleton
     
    -class BooleanType(PrimitiveType):
    +
    +class BooleanType(AtomicType):
         """Boolean data type.
         """
     
    +    __metaclass__ = DataTypeSingleton
    +
     
    -class DateType(PrimitiveType):
    +class DateType(AtomicType):
         """Date (datetime.date) data type.
         """
     
    +    __metaclass__ = DataTypeSingleton
    +
    +    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
    +
    +    def needConversion(self):
    +        return True
    +
    +    def toInternal(self, d):
    +        return d and d.toordinal() - self.EPOCH_ORDINAL
    +
    +    def fromInternal(self, v):
    +        return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
    +
     
    -class TimestampType(PrimitiveType):
    +class TimestampType(AtomicType):
         """Timestamp (datetime.datetime) data type.
         """
     
    +    __metaclass__ = DataTypeSingleton
    +
    +    def needConversion(self):
    +        return True
    +
    +    def toInternal(self, dt):
    +        if dt is not None:
    +            seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
    +                       else time.mktime(dt.timetuple()))
    +            return int(seconds * 1e6 + dt.microsecond)
    +
    +    def fromInternal(self, ts):
    +        if ts is not None:
    +            # using int to avoid precision loss in float
    +            return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)
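    For reference, a small round-trip sketch of these internal representations
    (assuming the module is importable as pyspark.sql.types after the rename, and that
    the tz-naive datetime is unambiguous in the local timezone, since toInternal goes
    through time.mktime for naive values):

        import datetime
        from pyspark.sql.types import DateType, TimestampType

        ts_type = TimestampType()
        dt = datetime.datetime(2015, 6, 1, 12, 30, 45, 123456)    # tz-naive
        micros = ts_type.toInternal(dt)             # int microseconds since the epoch
        assert ts_type.fromInternal(micros) == dt   # microseconds survive the round trip

        date_type = DateType()
        day = datetime.date(2015, 6, 1)
        assert date_type.toInternal(day) == day.toordinal() - DateType.EPOCH_ORDINAL
        assert date_type.fromInternal(date_type.toInternal(day)) == day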
    +
     
    -class DecimalType(DataType):
    +class DecimalType(FractionalType):
         """Decimal (decimal.Decimal) data type.
         """
     
    @@ -150,31 +220,35 @@ def __repr__(self):
                 return "DecimalType()"
     
     
    -class DoubleType(PrimitiveType):
    +class DoubleType(FractionalType):
         """Double data type, representing double precision floats.
         """
     
    +    __metaclass__ = DataTypeSingleton
     
    -class FloatType(PrimitiveType):
    +
    +class FloatType(FractionalType):
         """Float data type, representing single precision floats.
         """
     
    +    __metaclass__ = DataTypeSingleton
    +
     
    -class ByteType(PrimitiveType):
    +class ByteType(IntegralType):
         """Byte data type, i.e. a signed integer in a single byte.
         """
         def simpleString(self):
             return 'tinyint'
     
     
    -class IntegerType(PrimitiveType):
    +class IntegerType(IntegralType):
         """Int data type, i.e. a signed 32-bit integer.
         """
         def simpleString(self):
             return 'int'
     
     
    -class LongType(PrimitiveType):
    +class LongType(IntegralType):
         """Long data type, i.e. a signed 64-bit integer.
     
         If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
    @@ -184,7 +258,7 @@ def simpleString(self):
             return 'bigint'
     
     
    -class ShortType(PrimitiveType):
    +class ShortType(IntegralType):
         """Short data type, i.e. a signed 16-bit integer.
         """
         def simpleString(self):
    @@ -226,6 +300,19 @@ def fromJson(cls, json):
             return ArrayType(_parse_datatype_json_value(json["elementType"]),
                              json["containsNull"])
     
    +    def needConversion(self):
    +        return self.elementType.needConversion()
    +
    +    def toInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and [self.elementType.toInternal(v) for v in obj]
    +
    +    def fromInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and [self.elementType.fromInternal(v) for v in obj]
    +
     
     class MapType(DataType):
         """Map data type.
    @@ -271,6 +358,21 @@ def fromJson(cls, json):
                            _parse_datatype_json_value(json["valueType"]),
                            json["valueContainsNull"])
     
    +    def needConversion(self):
    +        return self.keyType.needConversion() or self.valueType.needConversion()
    +
    +    def toInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v))
    +                            for k, v in obj.items())
    +
    +    def fromInternal(self, obj):
    +        if not self.needConversion():
    +            return obj
    +        return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v))
    +                            for k, v in obj.items())
    +
     
     class StructField(DataType):
         """A field in :class:`StructType`.
    @@ -278,7 +380,7 @@ class StructField(DataType):
         :param name: string, name of the field.
         :param dataType: :class:`DataType` of the field.
         :param nullable: boolean, whether the field can be null (None) or not.
    -    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
    +    :param metadata: a dict from string to simple type that can be converted to JSON automatically
         """
     
         def __init__(self, name, dataType, nullable=True, metadata=None):
    @@ -291,6 +393,8 @@ def __init__(self, name, dataType, nullable=True, metadata=None):
             False
             """
             assert isinstance(dataType, DataType), "dataType should be DataType"
    +        if not isinstance(name, str):
    +            name = name.encode('utf-8')
             self.name = name
             self.dataType = dataType
             self.nullable = nullable
    @@ -316,14 +420,22 @@ def fromJson(cls, json):
                                json["nullable"],
                                json["metadata"])
     
    +    def needConversion(self):
    +        return self.dataType.needConversion()
    +
    +    def toInternal(self, obj):
    +        return self.dataType.toInternal(obj)
    +
    +    def fromInternal(self, obj):
    +        return self.dataType.fromInternal(obj)
    +
     
     class StructType(DataType):
         """Struct type, consisting of a list of :class:`StructField`.
     
         This is the data type representing a :class:`Row`.
         """
    -
    -    def __init__(self, fields):
    +    def __init__(self, fields=None):
             """
             >>> struct1 = StructType([StructField("f1", StringType(), True)])
             >>> struct2 = StructType([StructField("f1", StringType(), True)])
    @@ -335,8 +447,58 @@ def __init__(self, fields):
             >>> struct1 == struct2
             False
             """
    -        assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
    -        self.fields = fields
    +        if not fields:
    +            self.fields = []
    +            self.names = []
    +        else:
    +            self.fields = fields
    +            self.names = [f.name for f in fields]
    +            assert all(isinstance(f, StructField) for f in fields),\
    +                "fields should be a list of StructField"
    +        self._needSerializeFields = None
    +
    +    def add(self, field, data_type=None, nullable=True, metadata=None):
    +        """
    +        Construct a StructType by adding new elements to it to define the schema.
    +        The method accepts either:
    +            a) A single parameter which is a StructField object.
    +            b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
    +               metadata (optional)). The data_type parameter may be either a String
    +               or a DataType object.
    +
    +        >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
    +        >>> struct2 = StructType([StructField("f1", StringType(), True),\
    +         StructField("f2", StringType(), True, None)])
    +        >>> struct1 == struct2
    +        True
    +        >>> struct1 = StructType().add(StructField("f1", StringType(), True))
    +        >>> struct2 = StructType([StructField("f1", StringType(), True)])
    +        >>> struct1 == struct2
    +        True
    +        >>> struct1 = StructType().add("f1", "string", True)
    +        >>> struct2 = StructType([StructField("f1", StringType(), True)])
    +        >>> struct1 == struct2
    +        True
    +
    +        :param field: Either the name of the field or a StructField object
    +        :param data_type: If present, the DataType of the StructField to create
    +        :param nullable: Whether the field to add should be nullable (default True)
    +        :param metadata: Any additional metadata (default None)
    +        :return: a new updated StructType
    +        """
    +        if isinstance(field, StructField):
    +            self.fields.append(field)
    +            self.names.append(field.name)
    +        else:
    +            if isinstance(field, str) and data_type is None:
    +                raise ValueError("Must specify DataType if passing name of struct_field to create.")
    +
    +            if isinstance(data_type, str):
    +                data_type_f = _parse_datatype_json_value(data_type)
    +            else:
    +                data_type_f = data_type
    +            self.fields.append(StructField(field, data_type_f, nullable, metadata))
    +            self.names.append(field)
    +        return self
     
         def simpleString(self):
             return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
    @@ -353,6 +515,41 @@ def jsonValue(self):
         def fromJson(cls, json):
             return StructType([StructField.fromJson(f) for f in json["fields"]])
     
    +    def needConversion(self):
    +        # We need to convert Row()/namedtuple into tuple()
    +        return True
    +
    +    def toInternal(self, obj):
    +        if obj is None:
    +            return
    +
    +        if self._needSerializeFields is None:
    +            self._needSerializeFields = any(f.needConversion() for f in self.fields)
    +
    +        if self._needSerializeFields:
    +            if isinstance(obj, dict):
    +                return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
    +            elif isinstance(obj, (tuple, list)):
    +                return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
    +            else:
    +                raise ValueError("Unexpected tuple %r with StructType" % obj)
    +        else:
    +            if isinstance(obj, dict):
    +                return tuple(obj.get(n) for n in self.names)
    +            elif isinstance(obj, (list, tuple)):
    +                return tuple(obj)
    +            else:
    +                raise ValueError("Unexpected tuple %r with StructType" % obj)
    +
    +    def fromInternal(self, obj):
    +        if obj is None:
    +            return
    +        if isinstance(obj, Row):
    +            # it's already converted by pickler
    +            return obj
    +        values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
    +        return _create_row(self.names, values)
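    A hypothetical illustration of the struct conversion path (the field names below
    are made up): a dict or tuple is flattened into a plain tuple for the JVM, and
    fromInternal rebuilds a Row.

        import datetime
        from pyspark.sql.types import StructType, StringType, TimestampType

        schema = (StructType()
                  .add("name", StringType(), True)
                  .add("ts", TimestampType(), True))
        internal = schema.toInternal({"name": "a", "ts": datetime.datetime(2015, 6, 1)})
        # internal is a plain tuple: ("a", <microseconds since the epoch>)
        row = schema.fromInternal(internal)
        assert row.name == "a" and row.ts == datetime.datetime(2015, 6, 1)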
    +
     
     class UserDefinedType(DataType):
         """User-defined type (UDT).
    @@ -385,17 +582,35 @@ def scalaUDT(cls):
             """
             raise NotImplementedError("UDT must have a paired Scala UDT.")
     
    +    def needConversion(self):
    +        return True
    +
    +    @classmethod
    +    def _cachedSqlType(cls):
    +        """
    +        Cache the sqlType() on the class, because it's heavily used in `toInternal`.
    +        """
    +        if not hasattr(cls, "_cached_sql_type"):
    +            cls._cached_sql_type = cls.sqlType()
    +        return cls._cached_sql_type
    +
    +    def toInternal(self, obj):
    +        return self._cachedSqlType().toInternal(self.serialize(obj))
    +
    +    def fromInternal(self, obj):
    +        return self.deserialize(self._cachedSqlType().fromInternal(obj))
    +
         def serialize(self, obj):
             """
             Converts a user-type object into a SQL datum.
             """
    -        raise NotImplementedError("UDT must implement serialize().")
    +        raise NotImplementedError("UDT must implement toInternal().")
     
         def deserialize(self, datum):
             """
             Converts a SQL datum into a user-type object.
             """
    -        raise NotImplementedError("UDT must implement deserialize().")
    +        raise NotImplementedError("UDT must implement fromInternal().")
     
         def simpleString(self):
             return 'udt'
    @@ -426,11 +641,9 @@ def __eq__(self, other):
             return type(self) == type(other)
     
     
    -_all_primitive_types = dict((v.typeName(), v)
    -                            for v in list(globals().values())
    -                            if (type(v) is type or type(v) is PrimitiveTypeSingleton)
    -                            and v.__base__ == PrimitiveType)
    -
    +_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
    +                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType]
    +_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
     _all_complex_types = dict((v.typeName(), v)
                               for v in [ArrayType, MapType, StructType])
     
    @@ -444,7 +657,7 @@ def _parse_datatype_json_string(json_string):
         ...     scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
         ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
         ...     assert datatype == python_datatype
    -    >>> for cls in _all_primitive_types.values():
    +    >>> for cls in _all_atomic_types.values():
         ...     check_datatype(cls())
     
         >>> # Simple ArrayType.
    @@ -494,8 +707,8 @@ def _parse_datatype_json_string(json_string):
     
     def _parse_datatype_json_value(json_value):
         if not isinstance(json_value, dict):
    -        if json_value in _all_primitive_types.keys():
    -            return _all_primitive_types[json_value]()
    +        if json_value in _all_atomic_types.keys():
    +            return _all_atomic_types[json_value]()
             elif json_value == 'decimal':
                 return DecimalType()
             elif _FIXED_DECIMAL.match(json_value):
    @@ -594,93 +807,6 @@ def _infer_schema(row):
         return StructType(fields)
     
     
    -def _need_python_to_sql_conversion(dataType):
    -    """
    -    Checks whether we need python to sql conversion for the given type.
    -    For now, only UDTs need this conversion.
    -
    -    >>> _need_python_to_sql_conversion(DoubleType())
    -    False
    -    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
    -    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
    -    >>> _need_python_to_sql_conversion(schema0)
    -    False
    -    >>> _need_python_to_sql_conversion(ExamplePointUDT())
    -    True
    -    >>> schema1 = ArrayType(ExamplePointUDT(), False)
    -    >>> _need_python_to_sql_conversion(schema1)
    -    True
    -    >>> schema2 = StructType([StructField("label", DoubleType(), False),
    -    ...                       StructField("point", ExamplePointUDT(), False)])
    -    >>> _need_python_to_sql_conversion(schema2)
    -    True
    -    """
    -    if isinstance(dataType, StructType):
    -        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
    -    elif isinstance(dataType, ArrayType):
    -        return _need_python_to_sql_conversion(dataType.elementType)
    -    elif isinstance(dataType, MapType):
    -        return _need_python_to_sql_conversion(dataType.keyType) or \
    -            _need_python_to_sql_conversion(dataType.valueType)
    -    elif isinstance(dataType, UserDefinedType):
    -        return True
    -    else:
    -        return False
    -
    -
    -def _python_to_sql_converter(dataType):
    -    """
    -    Returns a converter that converts a Python object into a SQL datum for the given type.
    -
    -    >>> conv = _python_to_sql_converter(DoubleType())
    -    >>> conv(1.0)
    -    1.0
    -    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
    -    >>> conv([1.0, 2.0])
    -    [1.0, 2.0]
    -    >>> conv = _python_to_sql_converter(ExamplePointUDT())
    -    >>> conv(ExamplePoint(1.0, 2.0))
    -    [1.0, 2.0]
    -    >>> schema = StructType([StructField("label", DoubleType(), False),
    -    ...                      StructField("point", ExamplePointUDT(), False)])
    -    >>> conv = _python_to_sql_converter(schema)
    -    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
    -    (1.0, [1.0, 2.0])
    -    """
    -    if not _need_python_to_sql_conversion(dataType):
    -        return lambda x: x
    -
    -    if isinstance(dataType, StructType):
    -        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
    -        converters = [_python_to_sql_converter(t) for t in types]
    -
    -        def converter(obj):
    -            if isinstance(obj, dict):
    -                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
    -            elif isinstance(obj, tuple):
    -                if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
    -                    return tuple(c(v) for c, v in zip(converters, obj))
    -                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
    -                    d = dict(obj)
    -                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
    -                else:
    -                    return tuple(c(v) for c, v in zip(converters, obj))
    -            else:
    -                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
    -        return converter
    -    elif isinstance(dataType, ArrayType):
    -        element_converter = _python_to_sql_converter(dataType.elementType)
    -        return lambda a: [element_converter(v) for v in a]
    -    elif isinstance(dataType, MapType):
    -        key_converter = _python_to_sql_converter(dataType.keyType)
    -        value_converter = _python_to_sql_converter(dataType.valueType)
    -        return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
    -    elif isinstance(dataType, UserDefinedType):
    -        return lambda obj: dataType.serialize(obj)
    -    else:
    -        raise ValueError("Unexpected type %r" % dataType)
    -
    -
     def _has_nulltype(dt):
         """ Return whether there is NullType in `dt` or not """
         if isinstance(dt, StructType):
    @@ -930,7 +1056,7 @@ def _infer_schema_type(obj, dataType):
         DecimalType: (decimal.Decimal,),
         StringType: (str, unicode),
         BinaryType: (bytearray,),
    -    DateType: (datetime.date,),
    +    DateType: (datetime.date, datetime.datetime),
         TimestampType: (datetime.datetime,),
         ArrayType: (list, tuple, array),
         MapType: (dict,),
    @@ -968,19 +1094,26 @@ def _verify_type(obj, dataType):
         if obj is None:
             return
     
    +    # StringType can accept an object of any type
    +    if isinstance(dataType, StringType):
    +        return
    +
         if isinstance(dataType, UserDefinedType):
             if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                 raise ValueError("%r is not an instance of type %r" % (obj, dataType))
    -        _verify_type(dataType.serialize(obj), dataType.sqlType())
    +        _verify_type(dataType.toInternal(obj), dataType.sqlType())
             return
     
         _type = type(dataType)
         assert _type in _acceptable_types, "unknown datatype: %s" % dataType
     
    -    # subclass of them can not be deserialized in JVM
    -    if type(obj) not in _acceptable_types[_type]:
    -        raise TypeError("%s can not accept object in type %s"
    -                        % (dataType, type(obj)))
    +    if _type is StructType:
    +        if not isinstance(obj, (tuple, list)):
    +            raise TypeError("StructType can not accept object in type %s" % type(obj))
    +    else:
    +        # subclasses of them can not be deserialized in the JVM
    +        if type(obj) not in _acceptable_types[_type]:
    +            raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
     
         if isinstance(dataType, ArrayType):
             for i in obj:
    @@ -998,159 +1131,10 @@ def _verify_type(obj, dataType):
             for v, f in zip(obj, dataType.fields):
                 _verify_type(v, f.dataType)
     
    -_cached_cls = weakref.WeakValueDictionary()
    -
    -
    -def _restore_object(dataType, obj):
    -    """ Restore object during unpickling. """
    -    # use id(dataType) as key to speed up lookup in dict
    -    # Because of batched pickling, dataType will be the
    -    # same object in most cases.
    -    k = id(dataType)
    -    cls = _cached_cls.get(k)
    -    if cls is None or cls.__datatype is not dataType:
    -        # use dataType as key to avoid create multiple class
    -        cls = _cached_cls.get(dataType)
    -        if cls is None:
    -            cls = _create_cls(dataType)
    -            _cached_cls[dataType] = cls
    -        cls.__datatype = dataType
    -        _cached_cls[k] = cls
    -    return cls(obj)
    -
    -
    -def _create_object(cls, v):
    -    """ Create an customized object with class `cls`. """
    -    # datetime.date would be deserialized as datetime.datetime
    -    # from java type, so we need to set it back.
    -    if cls is datetime.date and isinstance(v, datetime.datetime):
    -        return v.date()
    -    return cls(v) if v is not None else v
    -
    -
    -def _create_getter(dt, i):
    -    """ Create a getter for item `i` with schema """
    -    cls = _create_cls(dt)
    -
    -    def getter(self):
    -        return _create_object(cls, self[i])
    -
    -    return getter
    -
    -
    -def _has_struct_or_date(dt):
    -    """Return whether `dt` is or has StructType/DateType in it"""
    -    if isinstance(dt, StructType):
    -        return True
    -    elif isinstance(dt, ArrayType):
    -        return _has_struct_or_date(dt.elementType)
    -    elif isinstance(dt, MapType):
    -        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
    -    elif isinstance(dt, DateType):
    -        return True
    -    elif isinstance(dt, UserDefinedType):
    -        return True
    -    return False
    -
    -
    -def _create_properties(fields):
    -    """Create properties according to fields"""
    -    ps = {}
    -    for i, f in enumerate(fields):
    -        name = f.name
    -        if (name.startswith("__") and name.endswith("__")
    -                or keyword.iskeyword(name)):
    -            warnings.warn("field name %s can not be accessed in Python,"
    -                          "use position to access it instead" % name)
    -        if _has_struct_or_date(f.dataType):
    -            # delay creating object until accessing it
    -            getter = _create_getter(f.dataType, i)
    -        else:
    -            getter = itemgetter(i)
    -        ps[name] = property(getter)
    -    return ps
    -
    -
    -def _create_cls(dataType):
    -    """
    -    Create an class by dataType
    -
    -    The created class is similar to namedtuple, but can have nested schema.
    -
    -    >>> schema = _parse_schema_abstract("a b c")
    -    >>> row = (1, 1.0, "str")
    -    >>> schema = _infer_schema_type(row, schema)
    -    >>> obj = _create_cls(schema)(row)
    -    >>> import pickle
    -    >>> pickle.loads(pickle.dumps(obj))
    -    Row(a=1, b=1.0, c='str')
    -
    -    >>> row = [[1], {"key": (1, 2.0)}]
    -    >>> schema = _parse_schema_abstract("a[] b{c d}")
    -    >>> schema = _infer_schema_type(row, schema)
    -    >>> obj = _create_cls(schema)(row)
    -    >>> pickle.loads(pickle.dumps(obj))
    -    Row(a=[1], b={'key': Row(c=1, d=2.0)})
    -    >>> pickle.loads(pickle.dumps(obj.a))
    -    [1]
    -    >>> pickle.loads(pickle.dumps(obj.b))
    -    {'key': Row(c=1, d=2.0)}
    -    """
    -
    -    if isinstance(dataType, ArrayType):
    -        cls = _create_cls(dataType.elementType)
    -
    -        def List(l):
    -            if l is None:
    -                return
    -            return [_create_object(cls, v) for v in l]
    -
    -        return List
    -
    -    elif isinstance(dataType, MapType):
    -        kcls = _create_cls(dataType.keyType)
    -        vcls = _create_cls(dataType.valueType)
    -
    -        def Dict(d):
    -            if d is None:
    -                return
    -            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
    -
    -        return Dict
    -
    -    elif isinstance(dataType, DateType):
    -        return datetime.date
    -
    -    elif isinstance(dataType, UserDefinedType):
    -        return lambda datum: dataType.deserialize(datum)
    -
    -    elif not isinstance(dataType, StructType):
    -        # no wrapper for primitive types
    -        return lambda x: x
    -
    -    class Row(tuple):
    -
    -        """ Row in DataFrame """
    -        __datatype = dataType
    -        __fields__ = tuple(f.name for f in dataType.fields)
    -        __slots__ = ()
    -
    -        # create property for fast access
    -        locals().update(_create_properties(dataType.fields))
    -
    -        def asDict(self):
    -            """ Return as a dict """
    -            return dict((n, getattr(self, n)) for n in self.__fields__)
    -
    -        def __repr__(self):
    -            # call collect __repr__ for nested objects
    -            return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
    -                                          for n in self.__fields__))
    -
    -        def __reduce__(self):
    -            return (_restore_object, (self.__datatype, tuple(self)))
     
    -    return Row
    +# This is used to unpickle a Row from JVM
    +def _create_row_inbound_converter(dataType):
    +    return lambda *a: dataType.fromInternal(a)
     
     
     def _create_row(fields, values):
    diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
    new file mode 100644
    index 0000000000000..cc5b2c088b7cc
    --- /dev/null
    +++ b/python/pyspark/sql/utils.py
    @@ -0,0 +1,54 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +import py4j
    +
    +
    +class AnalysisException(Exception):
    +    """
    +    Failed to analyze a SQL query plan.
    +    """
    +
    +
    +def capture_sql_exception(f):
    +    def deco(*a, **kw):
    +        try:
    +            return f(*a, **kw)
    +        except py4j.protocol.Py4JJavaError as e:
    +            s = e.java_exception.toString()
    +            if s.startswith('org.apache.spark.sql.AnalysisException: '):
    +                raise AnalysisException(s.split(': ', 1)[1])
    +            raise
    +    return deco
    +
    +
    +def install_exception_handler():
    +    """
     +    Hook an exception handler into Py4J so that certain SQL exceptions thrown in Java are captured.
    +
     +    When calling the Java API, Py4J uses `get_return_value` to parse the returned object.
     +    If an exception was thrown in the JVM, the result is a Java exception object and a
     +    py4j.protocol.Py4JJavaError is raised. We replace the original `get_return_value` with a
     +    wrapper that captures the Java exception and raises a Python one (with the same error
     +    message).
     +
     +    It is idempotent and can be called multiple times.
    +    """
    +    original = py4j.protocol.get_return_value
     +    # Always wrap the pristine `get_return_value`, never an already-patched one, so this is idempotent.
    +    patched = capture_sql_exception(original)
     +    # only patch the copy used in py4j.java_gateway (which makes the Java API calls)
    +    py4j.java_gateway.get_return_value = patched
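
A minimal sketch of how the new exception handling is intended to surface on the user side, assuming `install_exception_handler()` has already been hooked in (for example by the SQLContext constructor); the table name is a placeholder and the query is expected to fail analysis:

```
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.utils import AnalysisException

sc = SparkContext("local[2]", "analysis-exception-demo")
sqlContext = SQLContext(sc)

try:
    # referencing a table that was never registered fails analysis on the JVM side
    sqlContext.sql("SELECT * FROM table_that_does_not_exist")
except AnalysisException as e:
    # without the handler this would surface as a raw Py4JJavaError
    print("caught: %s" % e)
finally:
    sc.stop()
```
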
    diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
    new file mode 100644
    index 0000000000000..c74745c726a0c
    --- /dev/null
    +++ b/python/pyspark/sql/window.py
    @@ -0,0 +1,157 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +import sys
    +
    +from pyspark import SparkContext
    +from pyspark.sql import since
    +from pyspark.sql.column import _to_seq, _to_java_column
    +
    +__all__ = ["Window", "WindowSpec"]
    +
    +
    +def _to_java_cols(cols):
    +    sc = SparkContext._active_spark_context
    +    if len(cols) == 1 and isinstance(cols[0], list):
    +        cols = cols[0]
    +    return _to_seq(sc, cols, _to_java_column)
    +
    +
    +class Window(object):
    +    """
     +    Utility functions for defining windows in DataFrames.
    +
    +    For example:
    +
    +    >>> # PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    +    >>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxsize, 0)
    +
    +    >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING
    +    >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3)
    +
    +    .. note:: Experimental
    +
    +    .. versionadded:: 1.4
    +    """
    +    @staticmethod
    +    @since(1.4)
    +    def partitionBy(*cols):
    +        """
    +        Creates a :class:`WindowSpec` with the partitioning defined.
    +        """
    +        sc = SparkContext._active_spark_context
    +        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols))
    +        return WindowSpec(jspec)
    +
    +    @staticmethod
    +    @since(1.4)
    +    def orderBy(*cols):
    +        """
     +        Creates a :class:`WindowSpec` with the ordering defined.
    +        """
    +        sc = SparkContext._active_spark_context
     +        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
    +        return WindowSpec(jspec)
    +
    +
    +class WindowSpec(object):
    +    """
    +    A window specification that defines the partitioning, ordering,
    +    and frame boundaries.
    +
    +    Use the static methods in :class:`Window` to create a :class:`WindowSpec`.
    +
    +    .. note:: Experimental
    +
    +    .. versionadded:: 1.4
    +    """
    +
    +    _JAVA_MAX_LONG = (1 << 63) - 1
    +    _JAVA_MIN_LONG = - (1 << 63)
    +
    +    def __init__(self, jspec):
    +        self._jspec = jspec
    +
    +    @since(1.4)
    +    def partitionBy(self, *cols):
    +        """
    +        Defines the partitioning columns in a :class:`WindowSpec`.
    +
    +        :param cols: names of columns or expressions
    +        """
    +        return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols)))
    +
    +    @since(1.4)
    +    def orderBy(self, *cols):
    +        """
    +        Defines the ordering columns in a :class:`WindowSpec`.
    +
    +        :param cols: names of columns or expressions
    +        """
    +        return WindowSpec(self._jspec.orderBy(_to_java_cols(cols)))
    +
    +    @since(1.4)
    +    def rowsBetween(self, start, end):
    +        """
    +        Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
    +
    +        Both `start` and `end` are relative positions from the current row.
    +        For example, "0" means "current row", while "-1" means the row before
    +        the current row, and "5" means the fifth row after the current row.
    +
    +        :param start: boundary start, inclusive.
    +                      The frame is unbounded if this is ``-sys.maxsize`` (or lower).
    +        :param end: boundary end, inclusive.
    +                    The frame is unbounded if this is ``sys.maxsize`` (or higher).
    +        """
    +        if start <= -sys.maxsize:
    +            start = self._JAVA_MIN_LONG
    +        if end >= sys.maxsize:
    +            end = self._JAVA_MAX_LONG
    +        return WindowSpec(self._jspec.rowsBetween(start, end))
    +
    +    @since(1.4)
    +    def rangeBetween(self, start, end):
    +        """
    +        Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
    +
     +        Both `start` and `end` are relative to the current row. For example,
     +        "0" means "current row", while "-1" means one off before the current row,
     +        and "5" means five off after the current row.
    +
    +        :param start: boundary start, inclusive.
    +                      The frame is unbounded if this is ``-sys.maxsize`` (or lower).
    +        :param end: boundary end, inclusive.
    +                    The frame is unbounded if this is ``sys.maxsize`` (or higher).
    +        """
    +        if start <= -sys.maxsize:
    +            start = self._JAVA_MIN_LONG
    +        if end >= sys.maxsize:
    +            end = self._JAVA_MAX_LONG
    +        return WindowSpec(self._jspec.rangeBetween(start, end))
    +
    +
    +def _test():
    +    import doctest
    +    SparkContext('local[4]', 'PythonTest')
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
    +
    +
    +if __name__ == "__main__":
    +    _test()
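
A brief usage sketch of the Window/WindowSpec API above, assuming Hive support is built in (window functions in 1.4 require a HiveContext) and assuming `Column.over` from the companion column.py changes; the column names and data are placeholders:

```
from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql import functions as F
from pyspark.sql.window import Window

sc = SparkContext("local[2]", "window-demo")
sqlContext = HiveContext(sc)

df = sqlContext.createDataFrame(
    [("a", 1), ("a", 2), ("b", 3)], ["key", "value"])

# running mean of `value` within each key, ordered by value
w = Window.partitionBy("key").orderBy("value")
df.select("key", "value", F.mean("value").over(w).alias("avg")).show()

sc.stop()
```
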
    diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
    index ff097985fae3e..8dcb9645cdc6b 100644
    --- a/python/pyspark/streaming/dstream.py
    +++ b/python/pyspark/streaming/dstream.py
    @@ -176,7 +176,7 @@ def takeAndPrint(time, rdd):
                     print(record)
                 if len(taken) > num:
                     print("...")
    -            print()
    +            print("")
     
             self.foreachRDD(takeAndPrint)
     
    diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py
    new file mode 100644
    index 0000000000000..cbb573f226bbe
    --- /dev/null
    +++ b/python/pyspark/streaming/flume.py
    @@ -0,0 +1,147 @@
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +import sys
    +if sys.version >= "3":
    +    from io import BytesIO
    +else:
    +    from StringIO import StringIO
    +from py4j.java_gateway import Py4JJavaError
    +
    +from pyspark.storagelevel import StorageLevel
    +from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int
    +from pyspark.streaming import DStream
    +
    +__all__ = ['FlumeUtils', 'utf8_decoder']
    +
    +
    +def utf8_decoder(s):
    +    """ Decode the unicode as UTF-8 """
    +    return s and s.decode('utf-8')
    +
    +
    +class FlumeUtils(object):
    +
    +    @staticmethod
    +    def createStream(ssc, hostname, port,
    +                     storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
    +                     enableDecompression=False,
    +                     bodyDecoder=utf8_decoder):
    +        """
    +        Create an input stream that pulls events from Flume.
    +
    +        :param ssc:  StreamingContext object
    +        :param hostname:  Hostname of the slave machine to which the flume data will be sent
    +        :param port:  Port of the slave machine to which the flume data will be sent
    +        :param storageLevel:  Storage level to use for storing the received objects
    +        :param enableDecompression:  Should netty server decompress input stream
    +        :param bodyDecoder:  A function used to decode body (default is utf8_decoder)
    +        :return: A DStream object
    +        """
    +        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
    +
    +        try:
    +            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
    +                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
    +            helper = helperClass.newInstance()
    +            jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
    +        except Py4JJavaError as e:
    +            if 'ClassNotFoundException' in str(e.java_exception):
    +                FlumeUtils._printErrorMsg(ssc.sparkContext)
    +            raise e
    +
    +        return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
    +
    +    @staticmethod
    +    def createPollingStream(ssc, addresses,
    +                            storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
    +                            maxBatchSize=1000,
    +                            parallelism=5,
    +                            bodyDecoder=utf8_decoder):
    +        """
    +        Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent.
    +        This stream will poll the sink for data and will pull events as they are available.
    +
    +        :param ssc:  StreamingContext object
    +        :param addresses:  List of (host, port)s on which the Spark Sink is running.
    +        :param storageLevel:  Storage level to use for storing the received objects
    +        :param maxBatchSize:  The maximum number of events to be pulled from the Spark sink
    +                              in a single RPC call
    +        :param parallelism:  Number of concurrent requests this stream should send to the sink.
     +                             Note that a higher number of concurrent requests will result in
     +                             this stream using more threads
    +        :param bodyDecoder:  A function used to decode body (default is utf8_decoder)
    +        :return: A DStream object
    +        """
    +        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
    +        hosts = []
    +        ports = []
    +        for (host, port) in addresses:
    +            hosts.append(host)
    +            ports.append(port)
    +
    +        try:
    +            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
    +                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
    +            helper = helperClass.newInstance()
    +            jstream = helper.createPollingStream(
    +                ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
    +        except Py4JJavaError as e:
    +            if 'ClassNotFoundException' in str(e.java_exception):
    +                FlumeUtils._printErrorMsg(ssc.sparkContext)
    +            raise e
    +
    +        return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
    +
    +    @staticmethod
    +    def _toPythonDStream(ssc, jstream, bodyDecoder):
    +        ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
    +        stream = DStream(jstream, ssc, ser)
    +
    +        def func(event):
    +            headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0])
    +            headers = {}
    +            strSer = UTF8Deserializer()
    +            for i in range(0, read_int(headersBytes)):
    +                key = strSer.loads(headersBytes)
    +                value = strSer.loads(headersBytes)
    +                headers[key] = value
    +            body = bodyDecoder(event[1])
    +            return (headers, body)
    +        return stream.map(func)
    +
    +    @staticmethod
    +    def _printErrorMsg(sc):
    +        print("""
    +________________________________________________________________________________________________
    +
    +  Spark Streaming's Flume libraries not found in class path. Try one of the following.
    +
     +  1. Include the Flume library and its dependencies in the
    +     spark-submit command as
    +
    +     $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ...
    +
    +  2. Download the JAR of the artifact from Maven Central http://search.maven.org/,
    +     Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s.
    +     Then, include the jar in the spark-submit command as
    +
    +     $ bin/spark-submit --jars  ...
    +
    +________________________________________________________________________________________________
    +
    +""" % (sc.version, sc.version))
    diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
    index e278b29003f69..33dd596335b47 100644
    --- a/python/pyspark/streaming/kafka.py
    +++ b/python/pyspark/streaming/kafka.py
    @@ -21,6 +21,8 @@
     from pyspark.storagelevel import StorageLevel
     from pyspark.serializers import PairDeserializer, NoOpSerializer
     from pyspark.streaming import DStream
    +from pyspark.streaming.dstream import TransformedDStream
    +from pyspark.streaming.util import TransformFunction
     
     __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder']
     
    @@ -122,8 +124,9 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={},
                 raise e
     
             ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
    -        stream = DStream(jstream, ssc, ser)
    -        return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        stream = DStream(jstream, ssc, ser) \
    +            .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
     
         @staticmethod
         def createRDD(sc, kafkaParams, offsetRanges, leaders={},
    @@ -132,11 +135,12 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
             .. note:: Experimental
     
             Create a RDD from Kafka using offset ranges for each topic and partition.
    +
             :param sc:  SparkContext object
             :param kafkaParams: Additional params for Kafka
             :param offsetRanges:  list of offsetRange to specify topic:partition:[start, end) to consume
             :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges.  May be an empty
    -                        map, in which case leaders will be looked up on the driver.
    +            map, in which case leaders will be looked up on the driver.
             :param keyDecoder:  A function used to decode key (default is utf8_decoder)
             :param valueDecoder:  A function used to decode value (default is utf8_decoder)
             :return: A RDD object
    @@ -160,8 +164,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={},
                 raise e
     
             ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
    -        rdd = RDD(jrdd, sc, ser)
    -        return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
    +        return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer)
     
         @staticmethod
         def _printErrorMsg(sc):
    @@ -199,14 +203,30 @@ def __init__(self, topic, partition, fromOffset, untilOffset):
             :param fromOffset: Inclusive starting offset.
             :param untilOffset: Exclusive ending offset.
             """
    -        self._topic = topic
    -        self._partition = partition
    -        self._fromOffset = fromOffset
    -        self._untilOffset = untilOffset
    +        self.topic = topic
    +        self.partition = partition
    +        self.fromOffset = fromOffset
    +        self.untilOffset = untilOffset
    +
    +    def __eq__(self, other):
    +        if isinstance(other, self.__class__):
    +            return (self.topic == other.topic
    +                    and self.partition == other.partition
    +                    and self.fromOffset == other.fromOffset
    +                    and self.untilOffset == other.untilOffset)
    +        else:
    +            return False
    +
    +    def __ne__(self, other):
    +        return not self.__eq__(other)
    +
    +    def __str__(self):
    +        return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \
    +               % (self.topic, self.partition, self.fromOffset, self.untilOffset)
     
         def _jOffsetRange(self, helper):
    -        return helper.createOffsetRange(self._topic, self._partition, self._fromOffset,
    -                                        self._untilOffset)
    +        return helper.createOffsetRange(self.topic, self.partition, self.fromOffset,
    +                                        self.untilOffset)
     
     
     class TopicAndPartition(object):
    @@ -243,3 +263,87 @@ def __init__(self, host, port):
     
         def _jBroker(self, helper):
             return helper.createBroker(self._host, self._port)
    +
    +
    +class KafkaRDD(RDD):
    +    """
     +    A Python wrapper of KafkaRDD that provides additional information (such as offset ranges)
     +    on top of a normal RDD.
    +    """
    +
    +    def __init__(self, jrdd, ctx, jrdd_deserializer):
    +        RDD.__init__(self, jrdd, ctx, jrdd_deserializer)
    +
    +    def offsetRanges(self):
    +        """
     +        Get the OffsetRanges of this KafkaRDD.
     +
     +        :return: A list of OffsetRange
    +        """
    +        try:
    +            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
    +                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
    +            helper = helperClass.newInstance()
    +            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
    +        except Py4JJavaError as e:
    +            if 'ClassNotFoundException' in str(e.java_exception):
    +                KafkaUtils._printErrorMsg(self.ctx)
    +            raise e
    +
    +        ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
    +                  for o in joffsetRanges]
    +        return ranges
    +
    +
    +class KafkaDStream(DStream):
    +    """
    +    A Python wrapper of KafkaDStream
    +    """
    +
    +    def __init__(self, jdstream, ssc, jrdd_deserializer):
    +        DStream.__init__(self, jdstream, ssc, jrdd_deserializer)
    +
    +    def foreachRDD(self, func):
    +        """
    +        Apply a function to each RDD in this DStream.
    +        """
    +        if func.__code__.co_argcount == 1:
    +            old_func = func
    +            func = lambda r, rdd: old_func(rdd)
    +        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \
    +            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
    +        api = self._ssc._jvm.PythonDStream
    +        api.callForeachRDD(self._jdstream, jfunc)
    +
    +    def transform(self, func):
    +        """
    +        Return a new DStream in which each RDD is generated by applying a function
    +        on each RDD of this DStream.
    +
     +        `func` can take either a single argument, `rdd`, or two arguments,
     +        (`time`, `rdd`)
    +        """
    +        if func.__code__.co_argcount == 1:
    +            oldfunc = func
    +            func = lambda t, rdd: oldfunc(rdd)
    +        assert func.__code__.co_argcount == 2, "func should take one or two arguments"
    +
    +        return KafkaTransformedDStream(self, func)
    +
    +
    +class KafkaTransformedDStream(TransformedDStream):
    +    """
     +    Kafka-specific wrapper of TransformedDStream that transforms over KafkaRDDs.
    +    """
    +
    +    def __init__(self, prev, func):
    +        TransformedDStream.__init__(self, prev, func)
    +
    +    @property
    +    def _jdstream(self):
    +        if self._jdstream_val is not None:
    +            return self._jdstream_val
    +
    +        jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \
    +            .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser))
    +        dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
    +        self._jdstream_val = dstream.asJavaDStream()
    +        return self._jdstream_val
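
A hypothetical sketch of reading offset ranges from the direct Kafka stream via the new KafkaDStream/KafkaRDD wrappers, mirroring the tests further below; the broker address and topic are placeholders:

```
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.kafka import KafkaUtils

sc = SparkContext("local[2]", "kafka-offsets-demo")
ssc = StreamingContext(sc, 2)

stream = KafkaUtils.createDirectStream(
    ssc, ["my_topic"], {"metadata.broker.list": "localhost:9092"})

def report_offsets(time, rdd):
    # rdd is a KafkaRDD here, so offsetRanges() is available
    for o in rdd.offsetRanges():
        print("%s at batch time %s" % (o, time))

stream.foreachRDD(report_offsets)
ssc.start()
ssc.awaitTermination()
```
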
    diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
    index 33ea8c9293d74..4ecae1e4bf282 100644
    --- a/python/pyspark/streaming/tests.py
    +++ b/python/pyspark/streaming/tests.py
    @@ -15,6 +15,7 @@
     # limitations under the License.
     #
     
    +import glob
     import os
     import sys
     from itertools import chain
    @@ -37,12 +38,13 @@
     from pyspark.context import SparkConf, SparkContext, RDD
     from pyspark.streaming.context import StreamingContext
     from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
    +from pyspark.streaming.flume import FlumeUtils
     
     
     class PySparkStreamingTestCase(unittest.TestCase):
     
    -    timeout = 4  # seconds
    -    duration = .2
    +    timeout = 10  # seconds
    +    duration = .5
     
         @classmethod
         def setUpClass(cls):
    @@ -379,13 +381,13 @@ def func(dstream):
     
     class WindowFunctionTests(PySparkStreamingTestCase):
     
    -    timeout = 5
    +    timeout = 15
     
         def test_window(self):
             input = [range(1), range(2), range(3), range(4), range(5)]
     
             def func(dstream):
    -            return dstream.window(.6, .2).count()
    +            return dstream.window(1.5, .5).count()
     
             expected = [[1], [3], [6], [9], [12], [9], [5]]
             self._test_func(input, func, expected)
    @@ -394,7 +396,7 @@ def test_count_by_window(self):
             input = [range(1), range(2), range(3), range(4), range(5)]
     
             def func(dstream):
    -            return dstream.countByWindow(.6, .2)
    +            return dstream.countByWindow(1.5, .5)
     
             expected = [[1], [3], [6], [9], [12], [9], [5]]
             self._test_func(input, func, expected)
    @@ -403,7 +405,7 @@ def test_count_by_window_large(self):
             input = [range(1), range(2), range(3), range(4), range(5), range(6)]
     
             def func(dstream):
    -            return dstream.countByWindow(1, .2)
    +            return dstream.countByWindow(2.5, .5)
     
             expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
             self._test_func(input, func, expected)
    @@ -412,7 +414,7 @@ def test_count_by_value_and_window(self):
             input = [range(1), range(2), range(3), range(4), range(5), range(6)]
     
             def func(dstream):
    -            return dstream.countByValueAndWindow(1, .2)
    +            return dstream.countByValueAndWindow(2.5, .5)
     
             expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
             self._test_func(input, func, expected)
    @@ -421,7 +423,7 @@ def test_group_by_key_and_window(self):
             input = [[('a', i)] for i in range(5)]
     
             def func(dstream):
    -            return dstream.groupByKeyAndWindow(.6, .2).mapValues(list)
    +            return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list)
     
             expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
                         [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
    @@ -615,7 +617,6 @@ def test_kafka_stream(self):
     
             self._kafkaTestUtils.createTopic(topic)
             self._kafkaTestUtils.sendMessages(topic, sendData)
    -        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
     
             stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
                                              "test-streaming-consumer", {topic: 1},
    @@ -631,7 +632,6 @@ def test_kafka_direct_stream(self):
     
             self._kafkaTestUtils.createTopic(topic)
             self._kafkaTestUtils.sendMessages(topic, sendData)
    -        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
     
             stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
             self._validateStreamResult(sendData, stream)
    @@ -646,7 +646,6 @@ def test_kafka_direct_stream_from_offset(self):
     
             self._kafkaTestUtils.createTopic(topic)
             self._kafkaTestUtils.sendMessages(topic, sendData)
    -        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
     
             stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
             self._validateStreamResult(sendData, stream)
    @@ -661,7 +660,6 @@ def test_kafka_rdd(self):
     
             self._kafkaTestUtils.createTopic(topic)
             self._kafkaTestUtils.sendMessages(topic, sendData)
    -        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
             rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
             self._validateRddResult(sendData, rdd)
     
    @@ -677,9 +675,261 @@ def test_kafka_rdd_with_leaders(self):
     
             self._kafkaTestUtils.createTopic(topic)
             self._kafkaTestUtils.sendMessages(topic, sendData)
    -        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
             rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
             self._validateRddResult(sendData, rdd)
     
     +    @unittest.skipIf(sys.version >= "3", "long type not supported")
    +    def test_kafka_rdd_get_offsetRanges(self):
    +        """Test Python direct Kafka RDD get OffsetRanges."""
    +        topic = self._randomTopic()
    +        sendData = {"a": 3, "b": 4, "c": 5}
    +        offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))]
    +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()}
    +
    +        self._kafkaTestUtils.createTopic(topic)
    +        self._kafkaTestUtils.sendMessages(topic, sendData)
    +        rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
    +        self.assertEqual(offsetRanges, rdd.offsetRanges())
    +
     +    @unittest.skipIf(sys.version >= "3", "long type not supported")
    +    def test_kafka_direct_stream_foreach_get_offsetRanges(self):
    +        """Test the Python direct Kafka stream foreachRDD get offsetRanges."""
    +        topic = self._randomTopic()
    +        sendData = {"a": 1, "b": 2, "c": 3}
    +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
    +                       "auto.offset.reset": "smallest"}
    +
    +        self._kafkaTestUtils.createTopic(topic)
    +        self._kafkaTestUtils.sendMessages(topic, sendData)
    +
    +        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
    +
    +        offsetRanges = []
    +
    +        def getOffsetRanges(_, rdd):
    +            for o in rdd.offsetRanges():
    +                offsetRanges.append(o)
    +
    +        stream.foreachRDD(getOffsetRanges)
    +        self.ssc.start()
    +        self.wait_for(offsetRanges, 1)
    +
    +        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
    +
     +    @unittest.skipIf(sys.version >= "3", "long type not supported")
    +    def test_kafka_direct_stream_transform_get_offsetRanges(self):
    +        """Test the Python direct Kafka stream transform get offsetRanges."""
    +        topic = self._randomTopic()
    +        sendData = {"a": 1, "b": 2, "c": 3}
    +        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
    +                       "auto.offset.reset": "smallest"}
    +
    +        self._kafkaTestUtils.createTopic(topic)
    +        self._kafkaTestUtils.sendMessages(topic, sendData)
    +
    +        stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
    +
    +        offsetRanges = []
    +
    +        def transformWithOffsetRanges(rdd):
    +            for o in rdd.offsetRanges():
    +                offsetRanges.append(o)
    +            return rdd
    +
    +        stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count())
    +        self.ssc.start()
    +        self.wait_for(offsetRanges, 1)
    +
    +        self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
    +
    +
    +class FlumeStreamTests(PySparkStreamingTestCase):
    +    timeout = 20  # seconds
    +    duration = 1
    +
    +    def setUp(self):
    +        super(FlumeStreamTests, self).setUp()
    +
    +        utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
    +            .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils")
    +        self._utils = utilsClz.newInstance()
    +
    +    def tearDown(self):
    +        if self._utils is not None:
    +            self._utils.close()
    +            self._utils = None
    +
    +        super(FlumeStreamTests, self).tearDown()
    +
    +    def _startContext(self, n, compressed):
    +        # Start the StreamingContext and also collect the result
    +        dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(),
    +                                          enableDecompression=compressed)
    +        result = []
    +
    +        def get_output(_, rdd):
    +            for event in rdd.collect():
    +                if len(result) < n:
    +                    result.append(event)
    +        dstream.foreachRDD(get_output)
    +        self.ssc.start()
    +        return result
    +
    +    def _validateResult(self, input, result):
    +        # Validate both the header and the body
    +        header = {"test": "header"}
    +        self.assertEqual(len(input), len(result))
    +        for i in range(0, len(input)):
    +            self.assertEqual(header, result[i][0])
    +            self.assertEqual(input[i], result[i][1])
    +
    +    def _writeInput(self, input, compressed):
    +        # Try to write input to the receiver until success or timeout
    +        start_time = time.time()
    +        while True:
    +            try:
    +                self._utils.writeInput(input, compressed)
    +                break
    +            except:
    +                if time.time() - start_time < self.timeout:
    +                    time.sleep(0.01)
    +                else:
    +                    raise
    +
    +    def test_flume_stream(self):
    +        input = [str(i) for i in range(1, 101)]
    +        result = self._startContext(len(input), False)
    +        self._writeInput(input, False)
    +        self.wait_for(result, len(input))
    +        self._validateResult(input, result)
    +
    +    def test_compressed_flume_stream(self):
    +        input = [str(i) for i in range(1, 101)]
    +        result = self._startContext(len(input), True)
    +        self._writeInput(input, True)
    +        self.wait_for(result, len(input))
    +        self._validateResult(input, result)
    +
    +
    +class FlumePollingStreamTests(PySparkStreamingTestCase):
    +    timeout = 20  # seconds
    +    duration = 1
    +    maxAttempts = 5
    +
    +    def setUp(self):
    +        utilsClz = \
    +            self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
    +                .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils")
    +        self._utils = utilsClz.newInstance()
    +
    +    def tearDown(self):
    +        if self._utils is not None:
    +            self._utils.close()
    +            self._utils = None
    +
    +    def _writeAndVerify(self, ports):
    +        # Set up the streaming context and input streams
    +        ssc = StreamingContext(self.sc, self.duration)
    +        try:
    +            addresses = [("localhost", port) for port in ports]
    +            dstream = FlumeUtils.createPollingStream(
    +                ssc,
    +                addresses,
    +                maxBatchSize=self._utils.eventsPerBatch(),
    +                parallelism=5)
    +            outputBuffer = []
    +
    +            def get_output(_, rdd):
    +                for e in rdd.collect():
    +                    outputBuffer.append(e)
    +
    +            dstream.foreachRDD(get_output)
    +            ssc.start()
    +            self._utils.sendDatAndEnsureAllDataHasBeenReceived()
    +
    +            self.wait_for(outputBuffer, self._utils.getTotalEvents())
    +            outputHeaders = [event[0] for event in outputBuffer]
    +            outputBodies = [event[1] for event in outputBuffer]
    +            self._utils.assertOutput(outputHeaders, outputBodies)
    +        finally:
    +            ssc.stop(False)
    +
    +    def _testMultipleTimes(self, f):
    +        attempt = 0
    +        while True:
    +            try:
    +                f()
    +                break
    +            except:
    +                attempt += 1
    +                if attempt >= self.maxAttempts:
    +                    raise
    +                else:
    +                    import traceback
    +                    traceback.print_exc()
    +
    +    def _testFlumePolling(self):
    +        try:
    +            port = self._utils.startSingleSink()
    +            self._writeAndVerify([port])
    +            self._utils.assertChannelsAreEmpty()
    +        finally:
    +            self._utils.close()
    +
    +    def _testFlumePollingMultipleHosts(self):
    +        try:
    +            port = self._utils.startSingleSink()
    +            self._writeAndVerify([port])
    +            self._utils.assertChannelsAreEmpty()
    +        finally:
    +            self._utils.close()
    +
    +    def test_flume_polling(self):
    +        self._testMultipleTimes(self._testFlumePolling)
    +
    +    def test_flume_polling_multiple_hosts(self):
    +        self._testMultipleTimes(self._testFlumePollingMultipleHosts)
    +
    +
    +def search_kafka_assembly_jar():
    +    SPARK_HOME = os.environ["SPARK_HOME"]
    +    kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly")
    +    jars = glob.glob(
    +        os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar"))
    +    if not jars:
    +        raise Exception(
    +            ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) +
    +            "You need to build Spark with "
    +            "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or "
    +            "'build/mvn package' before running this test")
    +    elif len(jars) > 1:
    +        raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please "
    +                         "remove all but one") % kafka_assembly_dir)
    +    else:
    +        return jars[0]
    +
    +
    +def search_flume_assembly_jar():
    +    SPARK_HOME = os.environ["SPARK_HOME"]
    +    flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly")
    +    jars = glob.glob(
    +        os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar"))
    +    if not jars:
    +        raise Exception(
    +            ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) +
    +            "You need to build Spark with "
    +            "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or "
    +            "'build/mvn package' before running this test")
    +    elif len(jars) > 1:
    +        raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please "
    +                         "remove all but one") % flume_assembly_dir)
    +    else:
    +        return jars[0]
    +
     if __name__ == "__main__":
    +    kafka_assembly_jar = search_kafka_assembly_jar()
    +    flume_assembly_jar = search_flume_assembly_jar()
    +    jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar)
    +
    +    os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
         unittest.main()
    diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
    index 34291f30a5652..b20613b1283bd 100644
    --- a/python/pyspark/streaming/util.py
    +++ b/python/pyspark/streaming/util.py
    @@ -37,6 +37,11 @@ def __init__(self, ctx, func, *deserializers):
             self.ctx = ctx
             self.func = func
             self.deserializers = deserializers
    +        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
    +
    +    def rdd_wrapper(self, func):
    +        self._rdd_wrapper = func
    +        return self
     
         def call(self, milliseconds, jrdds):
             try:
    @@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds):
                 if len(sers) < len(jrdds):
                     sers += (sers[0],) * (len(jrdds) - len(sers))
     
    -            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
    +            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
                         for jrdd, ser in zip(jrdds, sers)]
                 t = datetime.fromtimestamp(milliseconds / 1000.0)
                 r = self.func(t, *rdds)
    @@ -125,4 +130,6 @@ def rddToFileName(prefix, suffix, timestamp):
     
     if __name__ == "__main__":
         import doctest
    -    doctest.testmod()
    +    (failure_count, test_count) = doctest.testmod()
    +    if failure_count:
    +        exit(-1)
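
A short sketch of the new `rdd_wrapper` hook on TransformFunction: a subclass of RDD can be substituted for the plain RDD that the callback would otherwise build, which is how KafkaDStream injects KafkaRDD above; `AnnotatedRDD` here is purely illustrative:

```
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.serializers import UTF8Deserializer
from pyspark.streaming.util import TransformFunction

sc = SparkContext("local[2]", "rdd-wrapper-demo")

class AnnotatedRDD(RDD):
    # hypothetical wrapper that could expose extra, source-specific metadata
    pass

# the callback will now receive AnnotatedRDD instances instead of plain RDDs
func = TransformFunction(sc, lambda time, rdd: rdd.count(), UTF8Deserializer()) \
    .rdd_wrapper(lambda jrdd, ctx, ser: AnnotatedRDD(jrdd, ctx, ser))

sc.stop()
```
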
    diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
    index 09de4d159fdcf..c5c0add49d02c 100644
    --- a/python/pyspark/tests.py
    +++ b/python/pyspark/tests.py
    @@ -179,9 +179,12 @@ def test_in_memory_sort(self):
                              list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
     
         def test_external_sort(self):
    +        class CustomizedSorter(ExternalSorter):
    +            def _next_limit(self):
    +                return self.memory_limit
             l = list(range(1024))
             random.shuffle(l)
    -        sorter = ExternalSorter(1)
    +        sorter = CustomizedSorter(1)
             self.assertEqual(sorted(l), list(sorter.sorted(l)))
             self.assertGreater(shuffle.DiskBytesSpilled, 0)
             last = shuffle.DiskBytesSpilled
    @@ -444,6 +447,11 @@ def func(x):
     
     class RDDTests(ReusedPySparkTestCase):
     
    +    def test_range(self):
    +        self.assertEqual(self.sc.range(1, 1).count(), 0)
    +        self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
    +        self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
    +
         def test_id(self):
             rdd = self.sc.parallelize(range(10))
             id = rdd.id()
    @@ -453,6 +461,14 @@ def test_id(self):
             self.assertEqual(id + 1, id2)
             self.assertEqual(id2, rdd2.id())
     
    +    def test_empty_rdd(self):
    +        rdd = self.sc.emptyRDD()
    +        self.assertTrue(rdd.isEmpty())
    +
    +    def test_sum(self):
    +        self.assertEqual(0, self.sc.emptyRDD().sum())
    +        self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
    +
         def test_save_as_textfile_with_unicode(self):
             # Regression test for SPARK-970
             x = u"\u00A1Hola, mundo!"
    @@ -869,6 +885,18 @@ def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
                 for size in sizes:
                     self.assertGreater(size, 0)
     
    +    def test_pipe_functions(self):
    +        data = ['1', '2', '3']
    +        rdd = self.sc.parallelize(data)
    +        with QuietTest(self.sc):
    +            self.assertEqual([], rdd.pipe('cc').collect())
    +            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
    +        result = rdd.pipe('cat').collect()
    +        result.sort()
    +        [self.assertEqual(x, y) for x, y in zip(data, result)]
    +        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
    +        self.assertEqual([], rdd.pipe('grep 4').collect())
    +
     
     class ProfilerTests(PySparkTestCase):
     
    @@ -1405,7 +1433,8 @@ def do_termination_test(self, terminator):
     
             # start daemon
             daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
    -        daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
    +        python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
    +        daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
     
             # read the port number
             port = read_int(daemon.stdout)
    @@ -1543,13 +1572,13 @@ def count():
         def test_with_different_versions_of_python(self):
             rdd = self.sc.parallelize(range(10))
             rdd.count()
    -        version = sys.version_info
    -        sys.version_info = (2, 0, 0)
    +        version = self.sc.pythonVer
    +        self.sc.pythonVer = "2.0"
             try:
                 with QuietTest(self.sc):
                     self.assertRaises(Py4JJavaError, lambda: rdd.count())
             finally:
    -            sys.version_info = version
    +            self.sc.pythonVer = version
     
     
     class SparkSubmitTests(unittest.TestCase):
    @@ -1804,6 +1833,10 @@ def run():
     
                 sc.stop()
     
    +    def test_startTime(self):
    +        with SparkContext() as sc:
    +            self.assertGreater(sc.startTime, 0)
    +
     
     @unittest.skipIf(not _have_scipy, "SciPy not installed")
     class SciPyTests(PySparkTestCase):
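
A brief sketch of the behaviours exercised by the new RDD tests above (sc.range, emptyRDD().sum(), and the checkCode flag of RDD.pipe); `grep` is assumed to be available on the worker machines:

```
from pyspark import SparkContext

sc = SparkContext("local[2]", "rdd-features-demo")

print(sc.range(0, 10, 2).collect())   # [0, 2, 4, 6, 8]
print(sc.emptyRDD().sum())            # 0

rdd = sc.parallelize(['1', '2', '3'])
print(rdd.pipe('grep 2').collect())   # ['2']

# with checkCode=True, a non-zero exit status of the piped command is raised
# as an error when the results are collected:
# rdd.pipe('grep 4', checkCode=True).collect()

sc.stop()
```
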
    diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
    index fbdaf3a5814cd..93df9002be377 100644
    --- a/python/pyspark/worker.py
    +++ b/python/pyspark/worker.py
    @@ -57,6 +57,12 @@ def main(infile, outfile):
             if split_index == -1:  # for unit tests
                 exit(-1)
     
    +        version = utf8_deserializer.loads(infile)
    +        if version != "%d.%d" % sys.version_info[:2]:
    +            raise Exception(("Python in worker has different version %s than that in " +
    +                             "driver %s, PySpark cannot run with different minor versions") %
    +                            ("%d.%d" % sys.version_info[:2], version))
    +
             # initialize global state
             shuffle.MemoryBytesSpilled = 0
             shuffle.DiskBytesSpilled = 0
    @@ -92,11 +98,7 @@ def main(infile, outfile):
             command = pickleSer._read_with_length(infile)
             if isinstance(command, Broadcast):
                 command = pickleSer.loads(command.value)
    -        (func, profiler, deserializer, serializer), version = command
    -        if version != sys.version_info[:2]:
    -            raise Exception(("Python in worker has different version %s than that in " +
    -                            "driver %s, PySpark cannot run with different minor versions") %
    -                            (sys.version_info[:2], version))
    +        func, profiler, deserializer, serializer = command
             init_time = time.time()
     
             def process():
    diff --git a/python/run-tests b/python/run-tests
    index f2757a3967e81..24949657ed7ab 100755
    --- a/python/run-tests
    +++ b/python/run-tests
    @@ -18,160 +18,7 @@
     #
     
     
    -# Figure out where the Spark framework is installed
    -FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)"
    +FWDIR="$(cd "`dirname $0`"/..; pwd)"
    +cd "$FWDIR"
     
    -. "$FWDIR"/bin/load-spark-env.sh
    -
    -# CD into the python directory to find things on the right path
    -cd "$FWDIR/python"
    -
    -FAILED=0
    -LOG_FILE=unit-tests.log
    -START=$(date +"%s")
    -
    -rm -f $LOG_FILE
    -
    -# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL
    -rm -rf metastore warehouse
    -
    -function run_test() {
    -    echo -en "Running test: $1 ... " | tee -a $LOG_FILE
    -    start=$(date +"%s")
    -    SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1
    -
    -    FAILED=$((PIPESTATUS[0]||$FAILED))
    -
    -    # Fail and exit on the first test failure.
    -    if [[ $FAILED != 0 ]]; then
    -        cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number.
    -        echo -en "\033[31m"  # Red
    -        echo "Had test failures; see logs."
    -        echo -en "\033[0m"  # No color
    -        exit -1
    -    else
    -        now=$(date +"%s")
    -        echo "ok ($(($now - $start))s)"
    -    fi
    -}
    -
    -function run_core_tests() {
    -    echo "Run core tests ..."
    -    run_test "pyspark/rdd.py"
    -    run_test "pyspark/context.py"
    -    run_test "pyspark/conf.py"
    -    PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
    -    PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
    -    run_test "pyspark/serializers.py"
    -    run_test "pyspark/profiler.py"
    -    run_test "pyspark/shuffle.py"
    -    run_test "pyspark/tests.py"
    -}
    -
    -function run_sql_tests() {
    -    echo "Run sql tests ..."
    -    run_test "pyspark/sql/_types.py"
    -    run_test "pyspark/sql/context.py"
    -    run_test "pyspark/sql/dataframe.py"
    -    run_test "pyspark/sql/functions.py"
    -    run_test "pyspark/sql/tests.py"
    -}
    -
    -function run_mllib_tests() {
    -    echo "Run mllib tests ..."
    -    run_test "pyspark/mllib/classification.py"
    -    run_test "pyspark/mllib/clustering.py"
    -    run_test "pyspark/mllib/evaluation.py"
    -    run_test "pyspark/mllib/feature.py"
    -    run_test "pyspark/mllib/fpm.py"
    -    run_test "pyspark/mllib/linalg.py"
    -    run_test "pyspark/mllib/rand.py"
    -    run_test "pyspark/mllib/recommendation.py"
    -    run_test "pyspark/mllib/regression.py"
    -    run_test "pyspark/mllib/stat/_statistics.py"
    -    run_test "pyspark/mllib/tree.py"
    -    run_test "pyspark/mllib/util.py"
    -    run_test "pyspark/mllib/tests.py"
    -}
    -
    -function run_ml_tests() {
    -    echo "Run ml tests ..."
    -    run_test "pyspark/ml/feature.py"
    -    run_test "pyspark/ml/classification.py"
    -    run_test "pyspark/ml/recommendation.py"
    -    run_test "pyspark/ml/regression.py"
    -    run_test "pyspark/ml/tuning.py"
    -    run_test "pyspark/ml/tests.py"
    -    run_test "pyspark/ml/evaluation.py"
    -}
    -
    -function run_streaming_tests() {
    -    echo "Run streaming tests ..."
    -
    -    KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly
    -    JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}"
    -    for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do
    -      if [[ ! -e "$f" ]]; then
    -        echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2
    -        echo "You need to build Spark with " \
    -             "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \
    -             "'build/mvn package' before running this program" 1>&2
    -        exit 1
    -      fi
    -      KAFKA_ASSEMBLY_JAR="$f"
    -    done
    -
    -    export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
    -    run_test "pyspark/streaming/util.py"
    -    run_test "pyspark/streaming/tests.py"
    -}
    -
    -echo "Running PySpark tests. Output is in python/$LOG_FILE."
    -
    -export PYSPARK_PYTHON="python"
    -
    -# Try to test with Python 2.6, since that's the minimum version that we support:
    -if [ $(which python2.6) ]; then
    -    export PYSPARK_PYTHON="python2.6"
    -fi
    -
    -echo "Testing with Python version:"
    -$PYSPARK_PYTHON --version
    -
    -run_core_tests
    -run_sql_tests
    -run_mllib_tests
    -run_ml_tests
    -run_streaming_tests
    -
    -# Try to test with Python 3
    -if [ $(which python3.4) ]; then
    -    export PYSPARK_PYTHON="python3.4"
    -    echo "Testing with Python3.4 version:"
    -    $PYSPARK_PYTHON --version
    -
    -    run_core_tests
    -    run_sql_tests
    -    run_mllib_tests
    -    run_ml_tests
    -    run_streaming_tests
    -fi
    -
    -# Try to test with PyPy
    -if [ $(which pypy) ]; then
    -    export PYSPARK_PYTHON="pypy"
    -    echo "Testing with PyPy version:"
    -    $PYSPARK_PYTHON --version
    -
    -    run_core_tests
    -    run_sql_tests
    -    run_streaming_tests
    -fi
    -
    -if [[ $FAILED == 0 ]]; then
    -    now=$(date +"%s")
    -    echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds"
    -fi
    -
    -# TODO: in the long-run, it would be nice to use a test runner like `nose`.
    -# The doctest fixtures are the current barrier to doing this.
    +exec python -u ./python/run-tests.py "$@"
    diff --git a/python/run-tests.py b/python/run-tests.py
    new file mode 100755
    index 0000000000000..cc560779373b3
    --- /dev/null
    +++ b/python/run-tests.py
    @@ -0,0 +1,214 @@
    +#!/usr/bin/env python
    +
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one or more
    +# contributor license agreements.  See the NOTICE file distributed with
    +# this work for additional information regarding copyright ownership.
    +# The ASF licenses this file to You under the Apache License, Version 2.0
    +# (the "License"); you may not use this file except in compliance with
    +# the License.  You may obtain a copy of the License at
    +#
    +#    http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing, software
    +# distributed under the License is distributed on an "AS IS" BASIS,
    +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +# See the License for the specific language governing permissions and
    +# limitations under the License.
    +#
    +
    +from __future__ import print_function
    +import logging
    +from optparse import OptionParser
    +import os
    +import re
    +import subprocess
    +import sys
    +import tempfile
    +from threading import Thread, Lock
    +import time
    +if sys.version < '3':
    +    import Queue
    +else:
    +    import queue as Queue
    +if sys.version_info >= (2, 7):
    +    subprocess_check_output = subprocess.check_output
    +else:
    +    # SPARK-8763
    +    # backported from subprocess module in Python 2.7
    +    def subprocess_check_output(*popenargs, **kwargs):
    +        if 'stdout' in kwargs:
    +            raise ValueError('stdout argument not allowed, it will be overridden.')
    +        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
    +        output, unused_err = process.communicate()
    +        retcode = process.poll()
    +        if retcode:
    +            cmd = kwargs.get("args")
    +            if cmd is None:
    +                cmd = popenargs[0]
    +            raise subprocess.CalledProcessError(retcode, cmd, output=output)
    +        return output
    +
    +
    +# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
    +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/"))
    +
    +
    +from sparktestsupport import SPARK_HOME  # noqa (suppress pep8 warnings)
    +from sparktestsupport.shellutils import which  # noqa
    +from sparktestsupport.modules import all_modules  # noqa
    +
    +
    +python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root')
    +
    +
    +def print_red(text):
    +    print('\033[31m' + text + '\033[0m')
    +
    +
    +LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
    +FAILURE_REPORTING_LOCK = Lock()
    +LOGGER = logging.getLogger()
    +
    +
    +def run_individual_python_test(test_name, pyspark_python):
    +    env = dict(os.environ)
    +    env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)})
    +    LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
    +    start_time = time.time()
    +    try:
    +        per_test_output = tempfile.TemporaryFile()
    +        retcode = subprocess.Popen(
    +            [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
    +            stderr=per_test_output, stdout=per_test_output, env=env).wait()
    +    except:
    +        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
    +        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
    +        # this code is invoked from a thread other than the main thread.
    +        os._exit(1)
    +    duration = time.time() - start_time
    +    # Exit on the first failure.
    +    if retcode != 0:
    +        try:
    +            with FAILURE_REPORTING_LOCK:
    +                with open(LOG_FILE, 'ab') as log_file:
    +                    per_test_output.seek(0)
    +                    log_file.writelines(per_test_output)
    +                per_test_output.seek(0)
    +                for line in per_test_output:
    +                    decoded_line = line.decode()
    +                    if not re.match('[0-9]+', decoded_line):
    +                        print(decoded_line, end='')
    +                per_test_output.close()
    +        except:
    +            LOGGER.exception("Got an exception while trying to print failed test output")
    +        finally:
    +            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
    +            # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
    +            # this code is invoked from a thread other than the main thread.
    +            os._exit(-1)
    +    else:
    +        per_test_output.close()
    +        LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
    +
    +
    +def get_default_python_executables():
    +    python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)]
    +    if "python2.6" not in python_execs:
    +        LOGGER.warning("Not testing against `python2.6` because it could not be found; falling"
    +                       " back to `python` instead")
    +        python_execs.insert(0, "python")
    +    return python_execs
    +
    +
    +def parse_opts():
    +    parser = OptionParser(
    +        prog="run-tests"
    +    )
    +    parser.add_option(
    +        "--python-executables", type="string", default=','.join(get_default_python_executables()),
    +        help="A comma-separated list of Python executables to test against (default: %default)"
    +    )
    +    parser.add_option(
    +        "--modules", type="string",
    +        default=",".join(sorted(python_modules.keys())),
    +        help="A comma-separated list of Python modules to test (default: %default)"
    +    )
    +    parser.add_option(
    +        "-p", "--parallelism", type="int", default=4,
    +        help="The number of suites to test in parallel (default %default)"
    +    )
    +    parser.add_option(
    +        "--verbose", action="store_true",
    +        help="Enable additional debug logging"
    +    )
    +
    +    (opts, args) = parser.parse_args()
    +    if args:
    +        parser.error("Unsupported arguments: %s" % ' '.join(args))
    +    if opts.parallelism < 1:
    +        parser.error("Parallelism cannot be less than 1")
    +    return opts
    +
    +
    +def main():
    +    opts = parse_opts()
    +    if (opts.verbose):
    +        log_level = logging.DEBUG
    +    else:
    +        log_level = logging.INFO
    +    logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
    +    LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE)
    +    if os.path.exists(LOG_FILE):
    +        os.remove(LOG_FILE)
    +    python_execs = opts.python_executables.split(',')
    +    modules_to_test = []
    +    for module_name in opts.modules.split(','):
    +        if module_name in python_modules:
    +            modules_to_test.append(python_modules[module_name])
    +        else:
    +            print("Error: unrecognized module %s" % module_name)
    +            sys.exit(-1)
    +    LOGGER.info("Will test against the following Python executables: %s", python_execs)
    +    LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
    +
    +    task_queue = Queue.Queue()
    +    for python_exec in python_execs:
    +        python_implementation = subprocess_check_output(
    +            [python_exec, "-c", "import platform; print(platform.python_implementation())"],
    +            universal_newlines=True).strip()
    +        LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
    +        LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output(
    +            [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
    +        for module in modules_to_test:
    +            if python_implementation not in module.blacklisted_python_implementations:
    +                for test_goal in module.python_test_goals:
    +                    task_queue.put((python_exec, test_goal))
    +
    +    def process_queue(task_queue):
    +        while True:
    +            try:
    +                (python_exec, test_goal) = task_queue.get_nowait()
    +            except Queue.Empty:
    +                break
    +            try:
    +                run_individual_python_test(test_goal, python_exec)
    +            finally:
    +                task_queue.task_done()
    +
    +    start_time = time.time()
    +    for _ in range(opts.parallelism):
    +        worker = Thread(target=process_queue, args=(task_queue,))
    +        worker.daemon = True
    +        worker.start()
    +    try:
    +        task_queue.join()
    +    except (KeyboardInterrupt, SystemExit):
    +        print_red("Exiting due to interrupt")
    +        sys.exit(-1)
    +    total_duration = time.time() - start_time
    +    LOGGER.info("Tests passed in %i seconds", total_duration)
    +
    +
    +if __name__ == "__main__":
    +    main()
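
The new runner above hands individual test goals to a small pool of daemon worker threads that pull from a shared queue, and it aborts the whole process on the first failing test. A minimal, self-contained sketch of that queue-plus-workers pattern is shown below; the task names and the run_task body are placeholders rather than real PySpark test goals.

    # Sketch only: the queue + daemon-worker pattern used by run-tests.py,
    # with placeholder task names instead of real PySpark test goals.
    import time
    from queue import Empty, Queue   # Python 3 name; run-tests.py aliases it for Python 2
    from threading import Thread


    def run_task(name):
        # Stand-in for run_individual_python_test(): pretend to do some work.
        time.sleep(0.1)
        print("finished %s" % name)


    def process_queue(task_queue):
        # Drain the queue; task_done() lets task_queue.join() in main() return.
        while True:
            try:
                name = task_queue.get_nowait()
            except Empty:
                break
            try:
                run_task(name)
            finally:
                task_queue.task_done()


    def main(parallelism=4):
        task_queue = Queue()
        for i in range(10):
            task_queue.put("task-%d" % i)
        for _ in range(parallelism):
            worker = Thread(target=process_queue, args=(task_queue,))
            worker.daemon = True    # do not block interpreter exit on a stuck worker
            worker.start()
        task_queue.join()           # returns once every queued task is marked done


    if __name__ == "__main__":
        main()

Note that the real runner calls os._exit() rather than sys.exit() when a test fails, because sys.exit() only raises SystemExit inside the worker thread that called it and would leave the other workers running.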
    diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS
    new file mode 100644
    index 0000000000000..e69de29bb2d1d
    diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata
    new file mode 100644
    index 0000000000000..7ef2320651dee
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_common_metadata differ
    diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata
    new file mode 100644
    index 0000000000000..78a1ca7d38279
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_metadata differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc
    new file mode 100644
    index 0000000000000..e93f42ed6f350
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet
    new file mode 100644
    index 0000000000000..461c382937ecd
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc
    new file mode 100644
    index 0000000000000..b63c4d6d1e1dc
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc
    new file mode 100644
    index 0000000000000..5bc0ebd713563
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet
    new file mode 100644
    index 0000000000000..62a63915beac2
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet
    new file mode 100644
    index 0000000000000..67665a7b55da6
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc
    new file mode 100644
    index 0000000000000..ae94a15d08c81
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet
    new file mode 100644
    index 0000000000000..6cb8538aa8904
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc
    new file mode 100644
    index 0000000000000..58d9bb5fc5883
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc differ
    diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet
    new file mode 100644
    index 0000000000000..9b00805481e7b
    Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet differ
    diff --git a/python/test_support/sql/people.json b/python/test_support/sql/people.json
    new file mode 100644
    index 0000000000000..50a859cbd7ee8
    --- /dev/null
    +++ b/python/test_support/sql/people.json
    @@ -0,0 +1,3 @@
    +{"name":"Michael"}
    +{"name":"Andy", "age":30}
    +{"name":"Justin", "age":19}
    diff --git a/repl/pom.xml b/repl/pom.xml
    index 03053b4c3b287..70c9bd7c01296 100644
    --- a/repl/pom.xml
    +++ b/repl/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
     
@@ -39,14 +39,16 @@
 
   <dependencies>
     <dependency>
-      <groupId>${jline.groupid}</groupId>
-      <artifactId>jline</artifactId>
-      <version>${jline.version}</version>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
@@ -86,7 +88,7 @@
     </dependency>
     <dependency>
       <groupId>org.mockito</groupId>
-      <artifactId>mockito-all</artifactId>
+      <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
   </dependencies>
@@ -154,6 +156,20 @@
     </plugins>
   </build>
   <profiles>
+    <profile>
+      <id>scala-2.10</id>
+      <activation>
+        <property><name>!scala-2.11</name></property>
+      </activation>
+      <dependencies>
+        <dependency>
+          <groupId>${jline.groupid}</groupId>
+          <artifactId>jline</artifactId>
+          <version>${jline.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
+
     <profile>
       <id>scala-2.11</id>
       <activation>
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
    index 6480e2d24e044..24fbbc12c08da 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
    @@ -39,6 +39,8 @@ class SparkCommandLine(args: List[String], override val settings: Settings)
       }
     
       def this(args: List[String]) {
    +    // scalastyle:off println
         this(args, str => Console.println("Error: " + str))
    +    // scalastyle:on println
       }
     }
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    index 488f3a9f33256..8f7f9074d3f03 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    @@ -206,7 +206,8 @@ class SparkILoop(
             // e.g. file:/C:/my/path.jar -> C:/my/path.jar
             SparkILoop.getAddedJars.map { jar => new URI(jar).getPath.stripPrefix("/") }
           } else {
    -        SparkILoop.getAddedJars
    +        // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20).
    +        SparkILoop.getAddedJars.map { jar => new URI(jar).getPath }
           }
         // work around for Scala bug
         val totalClassPath = addedJars.foldLeft(
    @@ -1100,7 +1101,9 @@ object SparkILoop extends Logging {
                 val s = super.readLine()
                 // helping out by printing the line being interpreted.
                 if (s != null)
    +              // scalastyle:off println
                   output.println(s)
    +              // scalastyle:on println
                 s
               }
             }
    @@ -1109,7 +1112,7 @@ object SparkILoop extends Logging {
             if (settings.classpath.isDefault)
               settings.classpath.value = sys.props("java.class.path")
     
    -        getAddedJars.foreach(settings.classpath.append(_))
    +        getAddedJars.map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_))
     
             repl process settings
           }
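
The change above maps every added jar through new URI(jar).getPath before it reaches the classpath, so that percent-encoded characters such as %20 are decoded back into their real characters. A rough illustration of that decoding step, written in Python for brevity and using a made-up jar URI, follows.

    # Sketch of the URI-to-path decoding that new URI(jar).getPath performs;
    # the jar location below is hypothetical.
    from urllib.parse import unquote, urlparse

    jar_uri = "file:/C:/my%20path/app.jar"
    jar_path = unquote(urlparse(jar_uri).path)
    print(jar_path)   # -> /C:/my path/app.jar, with %20 decoded to a space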
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
    index 05faef8786d2c..bd3314d94eed6 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
    @@ -80,11 +80,13 @@ private[repl] trait SparkILoopInit {
         if (!initIsComplete)
           withLock { while (!initIsComplete) initLoopCondition.await() }
         if (initError != null) {
    +      // scalastyle:off println
           println("""
             |Failed to initialize the REPL due to an unexpected error.
             |This is a bug, please, report it along with the error diagnostics printed below.
             |%s.""".stripMargin.format(initError)
           )
    +      // scalastyle:on println
           false
         } else true
       }
    diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    index 35fb625645022..4ee605fd7f11e 100644
    --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    @@ -1079,8 +1079,10 @@ import org.apache.spark.annotation.DeveloperApi
           throw new EvalException("Failed to load '" + path + "': " + ex.getMessage, ex)
     
         private def load(path: String): Class[_] = {
    +      // scalastyle:off classforname
           try Class.forName(path, true, classLoader)
           catch { case ex: Throwable => evalError(path, unwrap(ex)) }
    +      // scalastyle:on classforname
         }
     
         lazy val evalClass = load(evalPath)
    @@ -1761,7 +1763,9 @@ object SparkIMain {
             if (intp.totalSilence) ()
             else super.printMessage(msg)
           }
    +      // scalastyle:off println
           else Console.println(msg)
    +      // scalastyle:on println
         }
       }
     }
    diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
    index 934daaeaafca1..f150fec7db945 100644
    --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
    +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
    @@ -22,13 +22,12 @@ import java.net.URLClassLoader
     
     import scala.collection.mutable.ArrayBuffer
     
    -import org.scalatest.FunSuite
    -import org.apache.spark.SparkContext
    +import org.apache.spark.{SparkContext, SparkFunSuite}
     import org.apache.commons.lang3.StringEscapeUtils
     import org.apache.spark.util.Utils
     
     
    -class ReplSuite extends FunSuite {
    +class ReplSuite extends SparkFunSuite {
     
       def runInterpreter(master: String, input: String): String = {
         val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
    @@ -268,6 +267,17 @@ class ReplSuite extends FunSuite {
         assertDoesNotContain("Exception", output)
       }
     
    +  test("SPARK-8461 SQL with codegen") {
    +    val output = runInterpreter("local",
    +    """
    +      |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    +      |sqlContext.setConf("spark.sql.codegen", "true")
    +      |sqlContext.range(0, 100).filter('id > 50).count()
    +    """.stripMargin)
    +    assertContains("Long = 49", output)
    +    assertDoesNotContain("java.lang.ClassNotFoundException", output)
    +  }
    +
       test("SPARK-2632 importing a method from non serializable class and not using it.") {
         val output = runInterpreter("local",
         """
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
    index f4f4b626988e9..eed4a379afa60 100644
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
    +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
    @@ -17,13 +17,14 @@
     
     package org.apache.spark.repl
     
    +import java.io.File
    +
    +import scala.tools.nsc.Settings
    +
     import org.apache.spark.util.Utils
     import org.apache.spark._
     import org.apache.spark.sql.SQLContext
     
    -import scala.tools.nsc.Settings
    -import scala.tools.nsc.interpreter.SparkILoop
    -
     object Main extends Logging {
     
       val conf = new SparkConf()
    @@ -32,7 +33,8 @@ object Main extends Logging {
       val outputDir = Utils.createTempDir(rootDir)
       val s = new Settings()
       s.processArguments(List("-Yrepl-class-based",
    -    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true)
    +    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
    +    "-classpath", getAddedJars.mkString(File.pathSeparator)), true)
       val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
       var sparkContext: SparkContext = _
       var sqlContext: SQLContext = _
    @@ -48,7 +50,6 @@ object Main extends Logging {
         Option(sparkContext).map(_.stop)
       }
     
    -
       def getAddedJars: Array[String] = {
         val envJars = sys.env.get("ADD_JARS")
         if (envJars.isDefined) {
    @@ -84,10 +85,9 @@ object Main extends Logging {
         val loader = Utils.getContextOrSparkClassLoader
         try {
           sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext])
    -        .newInstance(sparkContext).asInstanceOf[SQLContext] 
    +        .newInstance(sparkContext).asInstanceOf[SQLContext]
           logInfo("Created sql context (with Hive support)..")
    -    }
    -    catch {
    +    } catch {
           case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError =>
             sqlContext = new SQLContext(sparkContext)
             logInfo("Created sql context..")
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
    deleted file mode 100644
    index 8e519fa67f649..0000000000000
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
    +++ /dev/null
    @@ -1,86 +0,0 @@
    -/* NSC -- new Scala compiler
    - * Copyright 2005-2013 LAMP/EPFL
    - * @author  Paul Phillips
    - */
    -
    -package scala.tools.nsc
    -package interpreter
    -
    -import scala.tools.nsc.ast.parser.Tokens.EOF
    -
    -trait SparkExprTyper {
    -  val repl: SparkIMain
    -
    -  import repl._
    -  import global.{ reporter => _, Import => _, _ }
    -  import naming.freshInternalVarName
    -
    -  def symbolOfLine(code: String): Symbol = {
    -    def asExpr(): Symbol = {
    -      val name  = freshInternalVarName()
    -      // Typing it with a lazy val would give us the right type, but runs
    -      // into compiler bugs with things like existentials, so we compile it
    -      // behind a def and strip the NullaryMethodType which wraps the expr.
    -      val line = "def " + name + " = " + code
    -
    -      interpretSynthetic(line) match {
    -        case IR.Success =>
    -          val sym0 = symbolOfTerm(name)
    -          // drop NullaryMethodType
    -          sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
    -        case _          => NoSymbol
    -      }
    -    }
    -    def asDefn(): Symbol = {
    -      val old = repl.definedSymbolList.toSet
    -
    -      interpretSynthetic(code) match {
    -        case IR.Success =>
    -          repl.definedSymbolList filterNot old match {
    -            case Nil        => NoSymbol
    -            case sym :: Nil => sym
    -            case syms       => NoSymbol.newOverloaded(NoPrefix, syms)
    -          }
    -        case _ => NoSymbol
    -      }
    -    }
    -    def asError(): Symbol = {
    -      interpretSynthetic(code)
    -      NoSymbol
    -    }
    -    beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
    -  }
    -
    -  private var typeOfExpressionDepth = 0
    -  def typeOfExpression(expr: String, silent: Boolean = true): Type = {
    -    if (typeOfExpressionDepth > 2) {
    -      repldbg("Terminating typeOfExpression recursion for expression: " + expr)
    -      return NoType
    -    }
    -    typeOfExpressionDepth += 1
    -    // Don't presently have a good way to suppress undesirable success output
    -    // while letting errors through, so it is first trying it silently: if there
    -    // is an error, and errors are desired, then it re-evaluates non-silently
    -    // to induce the error message.
    -    try beSilentDuring(symbolOfLine(expr).tpe) match {
    -      case NoType if !silent => symbolOfLine(expr).tpe // generate error
    -      case tpe               => tpe
    -    }
    -    finally typeOfExpressionDepth -= 1
    -  }
    -
    -  // This only works for proper types.
    -  def typeOfTypeString(typeString: String): Type = {
    -    def asProperType(): Option[Type] = {
    -      val name = freshInternalVarName()
    -      val line = "def %s: %s = ???" format (name, typeString)
    -      interpretSynthetic(line) match {
    -        case IR.Success =>
    -          val sym0 = symbolOfTerm(name)
    -          Some(sym0.asMethod.returnType)
    -        case _          => None
    -      }
    -    }
    -    beSilentDuring(asProperType()) getOrElse NoType
    -  }
    -}
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    index 7a5e94da5cbf3..bf609ff0f65fc 100644
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
    @@ -1,88 +1,64 @@
    -/* NSC -- new Scala compiler
    - * Copyright 2005-2013 LAMP/EPFL
    - * @author Alexander Spoon
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
      */
     
    -package scala
    -package tools.nsc
    -package interpreter
    +package org.apache.spark.repl
     
    -import scala.language.{ implicitConversions, existentials }
    -import scala.annotation.tailrec
    -import Predef.{ println => _, _ }
    -import interpreter.session._
    -import StdReplTags._
    -import scala.reflect.api.{Mirror, Universe, TypeCreator}
    -import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName }
    -import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
    -import scala.reflect.{ClassTag, classTag}
    -import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader }
    -import ScalaClassLoader._
    -import scala.reflect.io.{ File, Directory }
    -import scala.tools.util._
    -import scala.collection.generic.Clearable
    -import scala.concurrent.{ ExecutionContext, Await, Future, future }
    -import ExecutionContext.Implicits._
    -import java.io.{ BufferedReader, FileReader }
    +import java.io.{BufferedReader, FileReader}
     
    -/** The Scala interactive shell.  It provides a read-eval-print loop
    -  *  around the Interpreter class.
    -  *  After instantiation, clients should call the main() method.
    -  *
    -  *  If no in0 is specified, then input will come from the console, and
    -  *  the class will attempt to provide input editing feature such as
    -  *  input history.
    -  *
    -  *  @author Moez A. Abdel-Gawad
    -  *  @author  Lex Spoon
    -  *  @version 1.2
    -  */
    -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
    -  extends AnyRef
    -  with LoopCommands
    -{
    -  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
    -  def this() = this(None, new JPrintWriter(Console.out, true))
    -//
    -//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
    -//  @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i
    -
    -  var in: InteractiveReader = _   // the input stream from which commands come
    -  var settings: Settings = _
    -  var intp: SparkIMain = _
    +import Predef.{println => _, _}
    +import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName}
     
    -  var globalFuture: Future[Boolean] = _
    +import scala.tools.nsc.interpreter.{JPrintWriter, ILoop}
    +import scala.tools.nsc.Settings
    +import scala.tools.nsc.util.stringFromStream
     
    -  protected def asyncMessage(msg: String) {
    -    if (isReplInfo || isReplPower)
    -      echoAndRefresh(msg)
    -  }
    +/**
    + *  A Spark-specific interactive shell.
    + */
    +class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
    +    extends ILoop(in0, out) {
    +  def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
    +  def this() = this(None, new JPrintWriter(Console.out, true))
     
       def initializeSpark() {
         intp.beQuietDuring {
    -      command( """
    +      processLine("""
              @transient val sc = {
                val _sc = org.apache.spark.repl.Main.createSparkContext()
                println("Spark context available as sc.")
                _sc
              }
             """)
    -      command( """
    +      processLine("""
              @transient val sqlContext = {
                val _sqlContext = org.apache.spark.repl.Main.createSQLContext()
                println("SQL context available as sqlContext.")
                _sqlContext
              }
             """)
    -      command("import org.apache.spark.SparkContext._")
    -      command("import sqlContext.implicits._")
    -      command("import sqlContext.sql")
    -      command("import org.apache.spark.sql.functions._")
    +      processLine("import org.apache.spark.SparkContext._")
    +      processLine("import sqlContext.implicits._")
    +      processLine("import sqlContext.sql")
    +      processLine("import org.apache.spark.sql.functions._")
         }
       }
     
       /** Print a welcome message */
    -  def printWelcome() {
    +  override def printWelcome() {
         import org.apache.spark.SPARK_VERSION
         echo("""Welcome to
           ____              __
    @@ -98,875 +74,42 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
         echo("Type :help for more information.")
       }
     
    -  override def echoCommandMessage(msg: String) {
    -    intp.reporter printUntruncatedMessage msg
    -  }
    -
    -  // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
    -  def history = in.history
    -
    -  // classpath entries added via :cp
    -  var addedClasspath: String = ""
    -
    -  /** A reverse list of commands to replay if the user requests a :replay */
    -  var replayCommandStack: List[String] = Nil
    -
    -  /** A list of commands to replay if the user requests a :replay */
    -  def replayCommands = replayCommandStack.reverse
    -
    -  /** Record a command for replay should the user request a :replay */
    -  def addReplay(cmd: String) = replayCommandStack ::= cmd
    -
    -  def savingReplayStack[T](body: => T): T = {
    -    val saved = replayCommandStack
    -    try body
    -    finally replayCommandStack = saved
    -  }
    -  def savingReader[T](body: => T): T = {
    -    val saved = in
    -    try body
    -    finally in = saved
    -  }
    -
    -  /** Close the interpreter and set the var to null. */
    -  def closeInterpreter() {
    -    if (intp ne null) {
    -      intp.close()
    -      intp = null
    -    }
    -  }
    -
    -  class SparkILoopInterpreter extends SparkIMain(settings, out) {
    -    outer =>
    -
    -    override lazy val formatting = new Formatting {
    -      def prompt = SparkILoop.this.prompt
    -    }
    -    override protected def parentClassLoader =
    -      settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader )
    -  }
    -
    -  /** Create a new interpreter. */
    -  def createInterpreter() {
    -    if (addedClasspath != "")
    -      settings.classpath append addedClasspath
    -
    -    intp = new SparkILoopInterpreter
    -  }
    -
    -  /** print a friendly help message */
    -  def helpCommand(line: String): Result = {
    -    if (line == "") helpSummary()
    -    else uniqueCommand(line) match {
    -      case Some(lc) => echo("\n" + lc.help)
    -      case _        => ambiguousError(line)
    -    }
    -  }
    -  private def helpSummary() = {
    -    val usageWidth  = commands map (_.usageMsg.length) max
    -    val formatStr   = "%-" + usageWidth + "s %s"
    -
    -    echo("All commands can be abbreviated, e.g. :he instead of :help.")
    -
    -    commands foreach { cmd =>
    -      echo(formatStr.format(cmd.usageMsg, cmd.help))
    -    }
    -  }
    -  private def ambiguousError(cmd: String): Result = {
    -    matchingCommands(cmd) match {
    -      case Nil  => echo(cmd + ": no such command.  Type :help for help.")
    -      case xs   => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
    -    }
    -    Result(keepRunning = true, None)
    -  }
    -  private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
    -  private def uniqueCommand(cmd: String): Option[LoopCommand] = {
    -    // this lets us add commands willy-nilly and only requires enough command to disambiguate
    -    matchingCommands(cmd) match {
    -      case List(x)  => Some(x)
    -      // exact match OK even if otherwise appears ambiguous
    -      case xs       => xs find (_.name == cmd)
    -    }
    -  }
    -
    -  /** Show the history */
    -  lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
    -    override def usage = "[num]"
    -    def defaultLines = 20
    -
    -    def apply(line: String): Result = {
    -      if (history eq NoHistory)
    -        return "No history available."
    -
    -      val xs      = words(line)
    -      val current = history.index
    -      val count   = try xs.head.toInt catch { case _: Exception => defaultLines }
    -      val lines   = history.asStrings takeRight count
    -      val offset  = current - lines.size + 1
    -
    -      for ((line, index) <- lines.zipWithIndex)
    -        echo("%3d  %s".format(index + offset, line))
    -    }
    -  }
    -
    -  // When you know you are most likely breaking into the middle
    -  // of a line being typed.  This softens the blow.
    -  protected def echoAndRefresh(msg: String) = {
    -    echo("\n" + msg)
    -    in.redrawLine()
    -  }
    -  protected def echo(msg: String) = {
    -    out println msg
    -    out.flush()
    -  }
    -
    -  /** Search the history */
    -  def searchHistory(_cmdline: String) {
    -    val cmdline = _cmdline.toLowerCase
    -    val offset  = history.index - history.size + 1
    -
    -    for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
    -      echo("%d %s".format(index + offset, line))
    -  }
    -
    -  private val currentPrompt = Properties.shellPromptString
    -
    -  /** Prompt to print when awaiting input */
    -  def prompt = currentPrompt
    -
       import LoopCommand.{ cmd, nullary }
     
    -  /** Standard commands **/
    -  lazy val standardCommands = List(
    -    cmd("cp", "", "add a jar or directory to the classpath", addClasspath),
    -    cmd("edit", "|", "edit history", editCommand),
    -    cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
    -    historyCommand,
    -    cmd("h?", "", "search the history", searchHistory),
    -    cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
    -    //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand),
    -    cmd("javap", "", "disassemble a file or class name", javapCommand),
    -    cmd("line", "|", "place line(s) at the end of history", lineCommand),
    -    cmd("load", "", "interpret lines in a file", loadCommand),
    -    cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand),
    -    // nullary("power", "enable power user mode", powerCmd),
    -    nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)),
    -    nullary("replay", "reset execution and replay all previous commands", replay),
    -    nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
    -    cmd("save", "", "save replayable session to a file", saveCommand),
    -    shCommand,
    -    cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings),
    -    nullary("silent", "disable/enable automatic printing of results", verbosity),
    -//    cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand),
    -//    cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand),
    -    nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
    -  )
    -
    -  /** Power user commands */
    -//  lazy val powerCommands: List[LoopCommand] = List(
    -//    cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
    -//  )
    -
    -  private def importsCommand(line: String): Result = {
    -    val tokens    = words(line)
    -    val handlers  = intp.languageWildcardHandlers ++ intp.importHandlers
    -
    -    handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach {
    -      case (handler, idx) =>
    -        val (types, terms) = handler.importedSymbols partition (_.name.isTypeName)
    -        val imps           = handler.implicitSymbols
    -        val found          = tokens filter (handler importsSymbolNamed _)
    -        val typeMsg        = if (types.isEmpty) "" else types.size + " types"
    -        val termMsg        = if (terms.isEmpty) "" else terms.size + " terms"
    -        val implicitMsg    = if (imps.isEmpty) "" else imps.size + " are implicit"
    -        val foundMsg       = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
    -        val statsMsg       = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
    -
    -        intp.reporter.printMessage("%2d) %-30s %s%s".format(
    -          idx + 1,
    -          handler.importString,
    -          statsMsg,
    -          foundMsg
    -        ))
    -    }
    -  }
    -
    -  private def findToolsJar() = PathResolver.SupplementalLocations.platformTools
    +  private val blockedCommands = Set("implicits", "javap", "power", "type", "kind")
     
    -  private def addToolsJarToLoader() = {
    -    val cl = findToolsJar() match {
    -      case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
    -      case _           => intp.classLoader
    -    }
    -    if (Javap.isAvailable(cl)) {
    -      repldbg(":javap available.")
    -      cl
    -    }
    -    else {
    -      repldbg(":javap unavailable: no tools.jar at " + jdkHome)
    -      intp.classLoader
    -    }
    -  }
    -//
    -//  protected def newJavap() =
    -//    JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp))
    -//
    -//  private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
    -
    -  // Still todo: modules.
    -//  private def typeCommand(line0: String): Result = {
    -//    line0.trim match {
    -//      case "" => ":type [-v] "
    -//      case s  => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
    -//    }
    -//  }
    -
    -//  private def kindCommand(expr: String): Result = {
    -//    expr.trim match {
    -//      case "" => ":kind [-v] "
    -//      case s  => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
    -//    }
    -//  }
    -
    -  private def warningsCommand(): Result = {
    -    if (intp.lastWarnings.isEmpty)
    -      "Can't find any cached warnings."
    -    else
    -      intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
    -  }
    -
    -  private def changeSettings(args: String): Result = {
    -    def showSettings() = {
    -      for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString)
    -    }
    -    def updateSettings() = {
    -      // put aside +flag options
    -      val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+"))
    -      val tmps = new Settings
    -      val (ok, leftover) = tmps.processArguments(rest, processAll = true)
    -      if (!ok) echo("Bad settings request.")
    -      else if (leftover.nonEmpty) echo("Unprocessed settings.")
    -      else {
    -        // boolean flags set-by-user on tmp copy should be off, not on
    -        val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting])
    -        val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg))
    -        // update non-flags
    -        settings.processArguments(nonbools, processAll = true)
    -        // also snag multi-value options for clearing, e.g. -Ylog: and -language:
    -        for {
    -          s <- settings.userSetSettings
    -          if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting]
    -          if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init))
    -        } s match {
    -          case c: Clearable => c.clear()
    -          case _ =>
    -        }
    -        def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = {
    -          for (b <- bs)
    -            settings.lookupSetting(name(b)) match {
    -              case Some(s) =>
    -                if (s.isInstanceOf[Settings#BooleanSetting]) setter(s)
    -                else echo(s"Not a boolean flag: $b")
    -              case _ =>
    -                echo(s"Not an option: $b")
    -            }
    -        }
    -        update(minuses, identity, _.tryToSetFromPropertyValue("false"))  // turn off
    -        update(pluses, "-" + _.drop(1), _.tryToSet(Nil))                 // turn on
    -      }
    -    }
    -    if (args.isEmpty) showSettings() else updateSettings()
    -  }
    -
    -  private def javapCommand(line: String): Result = {
    -//    if (javap == null)
    -//      ":javap unavailable, no tools.jar at %s.  Set JDK_HOME.".format(jdkHome)
    -//    else if (line == "")
    -//      ":javap [-lcsvp] [path1 path2 ...]"
    -//    else
    -//      javap(words(line)) foreach { res =>
    -//        if (res.isError) return "Failed: " + res.value
    -//        else res.show()
    -//      }
    -  }
    -
    -  private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent"
    -
    -  private def phaseCommand(name: String): Result = {
    -//    val phased: Phased = power.phased
    -//    import phased.NoPhaseName
    -//
    -//    if (name == "clear") {
    -//      phased.set(NoPhaseName)
    -//      intp.clearExecutionWrapper()
    -//      "Cleared active phase."
    -//    }
    -//    else if (name == "") phased.get match {
-//      case NoPhaseName => "Usage: :phase <expr> (e.g. typer, erasure.next, erasure+3)"
    -//      case ph          => "Active phase is '%s'.  (To clear, :phase clear)".format(phased.get)
    -//    }
    -//    else {
    -//      val what = phased.parse(name)
    -//      if (what.isEmpty || !phased.set(what))
    -//        "'" + name + "' does not appear to represent a valid phase."
    -//      else {
    -//        intp.setExecutionWrapper(pathToPhaseWrapper)
    -//        val activeMessage =
    -//          if (what.toString.length == name.length) "" + what
    -//          else "%s (%s)".format(what, name)
    -//
    -//        "Active phase is now: " + activeMessage
    -//      }
    -//    }
    -  }
    +  /** Standard commands **/
    +  lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] =
    +    standardCommands.filter(cmd => !blockedCommands(cmd.name))
     
       /** Available commands */
    -  def commands: List[LoopCommand] = standardCommands ++ (
    -    // if (isReplPower)
    -    //  powerCommands
    -    // else
    -      Nil
    -    )
    -
    -  val replayQuestionMessage =
    -    """|That entry seems to have slain the compiler.  Shall I replay
    -      |your session? I can re-run each line except the last one.
    -      |[y/n]
    -    """.trim.stripMargin
    -
    -  private val crashRecovery: PartialFunction[Throwable, Boolean] = {
    -    case ex: Throwable =>
    -      val (err, explain) = (
    -        if (intp.isInitializeComplete)
    -          (intp.global.throwableAsString(ex), "")
    -        else
    -          (ex.getMessage, "The compiler did not initialize.\n")
    -        )
    -      echo(err)
    -
    -      ex match {
    -        case _: NoSuchMethodError | _: NoClassDefFoundError =>
    -          echo("\nUnrecoverable error.")
    -          throw ex
    -        case _  =>
    -          def fn(): Boolean =
    -            try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
    -            catch { case _: RuntimeException => false }
    -
    -          if (fn()) replay()
    -          else echo("\nAbandoning crashed session.")
    -      }
    -      true
    -  }
    -
    -  // return false if repl should exit
    -  def processLine(line: String): Boolean = {
    -    import scala.concurrent.duration._
    -    Await.ready(globalFuture, 60.seconds)
    -
    -    (line ne null) && (command(line) match {
    -      case Result(false, _)      => false
    -      case Result(_, Some(line)) => addReplay(line) ; true
    -      case _                     => true
    -    })
    -  }
    -
    -  private def readOneLine() = {
    -    out.flush()
    -    in readLine prompt
    -  }
    -
    -  /** The main read-eval-print loop for the repl.  It calls
    -    *  command() for each line of input, and stops when
    -    *  command() returns false.
    -    */
    -  @tailrec final def loop() {
    -    if ( try processLine(readOneLine()) catch crashRecovery )
    -      loop()
    -  }
    -
    -  /** interpret all lines from a specified file */
    -  def interpretAllFrom(file: File) {
    -    savingReader {
    -      savingReplayStack {
    -        file applyReader { reader =>
    -          in = SimpleReader(reader, out, interactive = false)
    -          echo("Loading " + file + "...")
    -          loop()
    -        }
    -      }
    -    }
    -  }
    -
    -  /** create a new interpreter and replay the given commands */
    -  def replay() {
    -    reset()
    -    if (replayCommandStack.isEmpty)
    -      echo("Nothing to replay.")
    -    else for (cmd <- replayCommands) {
    -      echo("Replaying: " + cmd)  // flush because maybe cmd will have its own output
    -      command(cmd)
    -      echo("")
    -    }
    -  }
    -  def resetCommand() {
    -    echo("Resetting interpreter state.")
    -    if (replayCommandStack.nonEmpty) {
    -      echo("Forgetting this session history:\n")
    -      replayCommands foreach echo
    -      echo("")
    -      replayCommandStack = Nil
    -    }
    -    if (intp.namedDefinedTerms.nonEmpty)
    -      echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
    -    if (intp.definedTypes.nonEmpty)
    -      echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
    -
    -    reset()
    -  }
    -  def reset() {
    -    intp.reset()
    -    unleashAndSetPhase()
    -  }
    -
    -  def lineCommand(what: String): Result = editCommand(what, None)
    -
    -  // :edit id or :edit line
    -  def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR"))
    -
    -  def editCommand(what: String, editor: Option[String]): Result = {
    -    def diagnose(code: String) = {
    -      echo("The edited code is incomplete!\n")
-      val errless = intp compileSources new BatchSourceFile("<pastie>", s"object pastel {\n$code\n}")
    -      if (errless) echo("The compiler reports no errors.")
    -    }
    -    def historicize(text: String) = history match {
    -      case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true
    -      case _ => false
    -    }
    -    def edit(text: String): Result = editor match {
    -      case Some(ed) =>
    -        val tmp = File.makeTemp()
    -        tmp.writeAll(text)
    -        try {
    -          val pr = new ProcessResult(s"$ed ${tmp.path}")
    -          pr.exitCode match {
    -            case 0 =>
    -              tmp.safeSlurp() match {
    -                case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.")
    -                case Some(edited) =>
    -                  echo(edited.lines map ("+" + _) mkString "\n")
    -                  val res = intp interpret edited
    -                  if (res == IR.Incomplete) diagnose(edited)
    -                  else {
    -                    historicize(edited)
    -                    Result(lineToRecord = Some(edited), keepRunning = true)
    -                  }
    -                case None => echo("Can't read edited text. Did you delete it?")
    -              }
    -            case x => echo(s"Error exit from $ed ($x), ignoring")
    -          }
    -        } finally {
    -          tmp.delete()
    -        }
    -      case None =>
    -        if (historicize(text)) echo("Placing text in recent history.")
    -        else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text")
    -    }
    -
    -    // if what is a number, use it as a line number or range in history
    -    def isNum = what forall (c => c.isDigit || c == '-' || c == '+')
    -    // except that "-" means last value
    -    def isLast = (what == "-")
    -    if (isLast || !isNum) {
    -      val name = if (isLast) intp.mostRecentVar else what
    -      val sym = intp.symbolOfIdent(name)
    -      intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match {
    -        case Some(req) => edit(req.line)
    -        case None      => echo(s"No symbol in scope: $what")
    -      }
    -    } else try {
    -      val s = what
    -      // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)
    -      val (start, len) =
    -        if ((s indexOf '+') > 0) {
    -          val (a,b) = s splitAt (s indexOf '+')
    -          (a.toInt, b.drop(1).toInt)
    -        } else {
    -          (s indexOf '-') match {
    -            case -1 => (s.toInt, 1)
    -            case 0  => val n = s.drop(1).toInt ; (history.index - n, n)
    -            case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n)
    -            case i  => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n)
    -          }
    -        }
    -      import scala.collection.JavaConverters._
    -      val index = (start - 1) max 0
    -      val text = history match {
    -        case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n"
    -        case _ => history.asStrings.slice(index, index + len) mkString "\n"
    -      }
    -      edit(text)
    -    } catch {
    -      case _: NumberFormatException => echo(s"Bad range '$what'")
    -        echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)")
    -    }
    -  }
    -
    -  /** fork a shell and run a command */
    -  lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
    -    override def usage = ""
    -    def apply(line: String): Result = line match {
    -      case ""   => showUsage()
    -      case _    =>
    -        val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})"
    -        intp interpret toRun
    -        ()
    -    }
    -  }
    -
    -  def withFile[A](filename: String)(action: File => A): Option[A] = {
    -    val res = Some(File(filename)) filter (_.exists) map action
    -    if (res.isEmpty) echo("That file does not exist")  // courtesy side-effect
    -    res
    -  }
    -
    -  def loadCommand(arg: String) = {
    -    var shouldReplay: Option[String] = None
    -    withFile(arg)(f => {
    -      interpretAllFrom(f)
    -      shouldReplay = Some(":load " + arg)
    -    })
    -    Result(keepRunning = true, shouldReplay)
    -  }
    -
    -  def saveCommand(filename: String): Result = (
    -    if (filename.isEmpty) echo("File name is required.")
    -    else if (replayCommandStack.isEmpty) echo("No replay commands in session")
    -    else File(filename).printlnAll(replayCommands: _*)
    -    )
    -
    -  def addClasspath(arg: String): Unit = {
    -    val f = File(arg).normalize
    -    if (f.exists) {
    -      addedClasspath = ClassPath.join(addedClasspath, f.path)
    -      val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
    -      echo("Added '%s'.  Your new classpath is:\n\"%s\"".format(f.path, totalClasspath))
    -      replay()
    -    }
    -    else echo("The path '" + f + "' doesn't seem to exist.")
    -  }
    -
    -  def powerCmd(): Result = {
    -    if (isReplPower) "Already in power mode."
    -    else enablePowerMode(isDuringInit = false)
    -  }
    -  def enablePowerMode(isDuringInit: Boolean) = {
    -    replProps.power setValue true
    -    unleashAndSetPhase()
    -    // asyncEcho(isDuringInit, power.banner)
    -  }
    -  private def unleashAndSetPhase() {
    -    if (isReplPower) {
    -    //  power.unleash()
    -      // Set the phase to "typer"
    -      // intp beSilentDuring phaseCommand("typer")
    -    }
    -  }
    -
    -  def asyncEcho(async: Boolean, msg: => String) {
    -    if (async) asyncMessage(msg)
    -    else echo(msg)
    -  }
    -
    -  def verbosity() = {
    -    val old = intp.printResults
    -    intp.printResults = !old
    -    echo("Switched " + (if (old) "off" else "on") + " result printing.")
    -  }
    -
    -  /** Run one command submitted by the user.  Two values are returned:
    -    * (1) whether to keep running, (2) the line to record for replay,
    -    * if any. */
    -  def command(line: String): Result = {
    -    if (line startsWith ":") {
    -      val cmd = line.tail takeWhile (x => !x.isWhitespace)
    -      uniqueCommand(cmd) match {
    -        case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace))
    -        case _        => ambiguousError(cmd)
    -      }
    -    }
    -    else if (intp.global == null) Result(keepRunning = false, None)  // Notice failure to create compiler
    -    else Result(keepRunning = true, interpretStartingWith(line))
    -  }
    -
    -  private def readWhile(cond: String => Boolean) = {
    -    Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
    -  }
    -
    -  def pasteCommand(arg: String): Result = {
    -    var shouldReplay: Option[String] = None
    -    def result = Result(keepRunning = true, shouldReplay)
    -    val (raw, file) =
    -      if (arg.isEmpty) (false, None)
    -      else {
    -        val r = """(-raw)?(\s+)?([^\-]\S*)?""".r
    -        arg match {
    -          case r(flag, sep, name) =>
    -            if (flag != null && name != null && sep == null)
    -              echo(s"""I assume you mean "$flag $name"?""")
    -            (flag != null, Option(name))
    -          case _ =>
    -            echo("usage: :paste -raw file")
    -            return result
    -        }
    -      }
    -    val code = file match {
    -      case Some(name) =>
    -        withFile(name)(f => {
    -          shouldReplay = Some(s":paste $arg")
    -          val s = f.slurp.trim
    -          if (s.isEmpty) echo(s"File contains no code: $f")
    -          else echo(s"Pasting file $f...")
    -          s
    -        }) getOrElse ""
    -      case None =>
    -        echo("// Entering paste mode (ctrl-D to finish)\n")
    -        val text = (readWhile(_ => true) mkString "\n").trim
    -        if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n")
    -        else echo("\n// Exiting paste mode, now interpreting.\n")
    -        text
    -    }
    -    def interpretCode() = {
    -      val res = intp interpret code
    -      // if input is incomplete, let the compiler try to say why
    -      if (res == IR.Incomplete) {
    -        echo("The pasted code is incomplete!\n")
    -        // Remembrance of Things Pasted in an object
    -        val errless = intp compileSources new BatchSourceFile("<pastie>", s"object pastel {\n$code\n}")
    -        if (errless) echo("...but compilation found no error? Good luck with that.")
    -      }
    -    }
    -    def compileCode() = {
    -      val errless = intp compileSources new BatchSourceFile("<pastie>", code)
    -      if (!errless) echo("There were compilation errors!")
    -    }
    -    if (code.nonEmpty) {
    -      if (raw) compileCode() else interpretCode()
    -    }
    -    result
    -  }
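
The regex at the top of `pasteCommand` splits the argument into an optional `-raw` flag, optional whitespace, and an optional file name; the "I assume you mean" message fires when the flag and name are run together. A small illustrative sketch of how that pattern behaves on a few inputs:

```scala
// Illustrative: how the (-raw)?(\s+)?([^\-]\S*)? pattern in pasteCommand splits its argument.
object PasteArgSketch {
  def main(args: Array[String]): Unit = {
    val r = """(-raw)?(\s+)?([^\-]\S*)?""".r
    Seq("", "-raw", "foo.scala", "-raw foo.scala", "-rawfoo.scala").foreach {
      case arg @ r(flag, sep, name) =>
        val missingSeparator = flag != null && name != null && sep == null
        println(s"'$arg' -> raw=${flag != null}, file=${Option(name)}, missingSeparator=$missingSeparator")
      case arg =>
        println(s"'$arg' -> usage: :paste -raw file")
    }
  }
}
```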
    -
    -  private object paste extends Pasted {
    -    val ContinueString = "     | "
    -    val PromptString   = "scala> "
    -
    -    def interpret(line: String): Unit = {
    -      echo(line.trim)
    -      intp interpret line
    -      echo("")
    -    }
    -
    -    def transcript(start: String) = {
    -      echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
    -      apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
    -    }
    -  }
    -  import paste.{ ContinueString, PromptString }
    -
    -  /** Interpret expressions starting with the first line.
    -    * Read lines until a complete compilation unit is available
    -    * or until a syntax error has been seen.  If a full unit is
    -    * read, go ahead and interpret it.  Return the full string
    -    * to be recorded for replay, if any.
    -    */
    -  def interpretStartingWith(code: String): Option[String] = {
    -    // signal the completion machinery that non-completion input has been received
    -    in.completion.resetVerbosity()
    -
    -    def reallyInterpret = {
    -      val reallyResult = intp.interpret(code)
    -      (reallyResult, reallyResult match {
    -        case IR.Error       => None
    -        case IR.Success     => Some(code)
    -        case IR.Incomplete  =>
    -          if (in.interactive && code.endsWith("\n\n")) {
    -            echo("You typed two blank lines.  Starting a new command.")
    -            None
    -          }
    -          else in.readLine(ContinueString) match {
    -            case null =>
    -              // we know compilation is going to fail since we're at EOF and the
    -              // parser thinks the input is still incomplete, but since this is
    -              // a file being read non-interactively we want to fail.  So we send
    -              // it straight to the compiler for the nice error message.
    -              intp.compileString(code)
    -              None
    -
    -            case line => interpretStartingWith(code + "\n" + line)
    -          }
    -      })
    -    }
    -
    -    /** Here we place ourselves between the user and the interpreter and examine
    -      *  the input they are ostensibly submitting.  We intervene in several cases:
    -      *
    -      *  1) If the line starts with "scala> " it is assumed to be an interpreter paste.
    -      *  2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
    -      *     on the previous result.
    -      *  3) If the Completion object's execute returns Some(_), we inject that value
    -      *     and avoid the interpreter, as it's likely not valid scala code.
    -      */
    -    if (code == "") None
    -    else if (!paste.running && code.trim.startsWith(PromptString)) {
    -      paste.transcript(code)
    -      None
    -    }
    -    else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
    -      interpretStartingWith(intp.mostRecentVar + code)
    -    }
    -    else if (code.trim startsWith "//") {
    -      // line comment, do nothing
    -      None
    -    }
    -    else
    -      reallyInterpret._2
    -  }
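
The numbered cases in the comment inside `interpretStartingWith` describe how input is rewritten before it reaches the interpreter; case (2) turns a leading "." into an invocation on the most recent result. A tiny hedged sketch of that rewrite (the variable name `res0` is just an example):

```scala
// Illustrative only: the leading-dot rewrite described in case (2) above.
object DotInvocationSketch {
  // Roughly what Completion.looksLikeInvocation checks: starts with "." but not ".." or "./".
  def looksLikeInvocation(code: String): Boolean =
    code != null && code.startsWith(".") && !code.startsWith("..") && !code.startsWith("./")

  def main(args: Array[String]): Unit = {
    val mostRecentVar = "res0"          // assumed name of the last REPL result
    val typed = ".toUpperCase"
    val submitted = if (looksLikeInvocation(typed)) mostRecentVar + typed else typed
    println(submitted)                  // prints: res0.toUpperCase
  }
}
```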
    -
    -  // runs :load `file` on any files passed via -i
    -  def loadFiles(settings: Settings) = settings match {
    -    case settings: GenericRunnerSettings =>
    -      for (filename <- settings.loadfiles.value) {
    -        val cmd = ":load " + filename
    -        command(cmd)
    -        addReplay(cmd)
    -        echo("")
    -      }
    -    case _ =>
    -  }
    -
    -  /** Tries to create a JLineReader, falling back to SimpleReader:
    -    *  unless settings or properties are such that it should start
    -    *  with SimpleReader.
    -    */
    -  def chooseReader(settings: Settings): InteractiveReader = {
    -    if (settings.Xnojline || Properties.isEmacsShell)
    -      SimpleReader()
    -    else try new JLineReader(
    -      if (settings.noCompletion) NoCompletion
    -      else new SparkJLineCompletion(intp)
    -    )
    -    catch {
    -      case ex @ (_: Exception | _: NoClassDefFoundError) =>
    -        echo("Failed to create JLineReader: " + ex + "\nFalling back to SimpleReader.")
    -        SimpleReader()
    -    }
    -  }
    -  protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
    -    u.TypeTag[T](
    -      m,
    -      new TypeCreator {
    -        def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type =
    -          m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
    -      })
    -
    -  private def loopPostInit() {
    -    // Bind intp somewhere out of the regular namespace where
    -    // we can get at it in generated code.
    -    intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain]))
    -    // Auto-run code via some setting.
    -    ( replProps.replAutorunCode.option
    -      flatMap (f => io.File(f).safeSlurp())
    -      foreach (intp quietRun _)
    -      )
    -    // classloader and power mode setup
    -    intp.setContextClassLoader()
    -    if (isReplPower) {
    -     // replProps.power setValue true
    -     // unleashAndSetPhase()
    -     // asyncMessage(power.banner)
    -    }
    -    // SI-7418 Now, and only now, can we enable TAB completion.
    -    in match {
    -      case x: JLineReader => x.consoleReader.postInit
    -      case _              =>
    -    }
    -  }
    -  def process(settings: Settings): Boolean = savingContextLoader {
    -    this.settings = settings
    -    createInterpreter()
    -
    -    // sets in to some kind of reader depending on environmental cues
    -    in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true))
    -    globalFuture = future {
    -      intp.initializeSynchronous()
    -      loopPostInit()
    -      !intp.reporter.hasErrors
    -    }
    -    import scala.concurrent.duration._
    -    Await.ready(globalFuture, 10 seconds)
    -    printWelcome()
    +  override def commands: List[LoopCommand] = sparkStandardCommands
    +
    +  /** 
    +   * We override `loadFiles` because we need to initialize Spark *before* the REPL
    +   * sees any files, so that the Spark context is visible in those files. This is a bit of a
    +   * hack, but there isn't another hook available to us at this point.
    +   */
    +  override def loadFiles(settings: Settings): Unit = {
         initializeSpark()
    -    loadFiles(settings)
    -
    -    try loop()
    -    catch AbstractOrMissingHandler()
    -    finally closeInterpreter()
    -
    -    true
    +    super.loadFiles(settings)
       }
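
The override above exists purely for ordering: `initializeSpark()` must run before any `-i` files are interpreted so that the Spark context is already bound when those files execute. A hedged sketch of that ordering guarantee; the trait and members below are simplified stand-ins, not the real ILoop API.

```scala
// Sketch of the ordering guarantee only: Spark setup happens before user files are loaded.
trait ReplHooks {
  def loadFiles(files: Seq[String]): Unit =
    files.foreach(f => println(s"interpreting $f"))
}

class SparkAwareRepl extends ReplHooks {
  private def initializeSpark(): Unit =
    println("binding the Spark context so loaded files can use it")

  override def loadFiles(files: Seq[String]): Unit = {
    initializeSpark()       // must happen first so the context is visible in the files
    super.loadFiles(files)
  }
}

object SparkAwareRepl {
  def main(args: Array[String]): Unit =
    new SparkAwareRepl().loadFiles(Seq("init.scala", "job.scala"))
}
```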
    -
    -  @deprecated("Use `process` instead", "2.9.0")
    -  def main(settings: Settings): Unit = process(settings) //used by sbt
     }
     
     object SparkILoop {
    -  implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
     
    -  // Designed primarily for use by test code: takes a String with a
    -  // bunch of code, and prints out a transcript of what it would look
    -  // like if you'd just typed it into the repl.
    -  def runForTranscript(code: String, settings: Settings): String = {
    -    import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
    -
    -    stringFromStream { ostream =>
    -      Console.withOut(ostream) {
    -        val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
    -          override def write(str: String) = {
    -            // completely skip continuation lines
    -            if (str forall (ch => ch.isWhitespace || ch == '|')) ()
    -            else super.write(str)
    -          }
    -        }
    -        val input = new BufferedReader(new StringReader(code.trim + "\n")) {
    -          override def readLine(): String = {
    -            val s = super.readLine()
    -            // helping out by printing the line being interpreted.
    -            if (s != null)
    -              output.println(s)
    -            s
    -          }
    -        }
    -        val repl = new SparkILoop(input, output)
    -        if (settings.classpath.isDefault)
    -          settings.classpath.value = sys.props("java.class.path")
    -
    -        repl process settings
    -      }
    -    }
    -  }
    -
    -  /** Creates an interpreter loop with default settings and feeds
    -    *  the given code to it as input.
    -    */
    +  /** 
    +   * Creates an interpreter loop with default settings and feeds
    +   * the given code to it as input.
    +   */
       def run(code: String, sets: Settings = new Settings): String = {
         import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
     
         stringFromStream { ostream =>
           Console.withOut(ostream) {
    -        val input    = new BufferedReader(new StringReader(code))
    -        val output   = new JPrintWriter(new OutputStreamWriter(ostream), true)
    -        val repl     = new SparkILoop(input, output)
    +        val input = new BufferedReader(new StringReader(code))
    +        val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
    +        val repl = new SparkILoop(input, output)
     
             if (sets.classpath.isDefault)
               sets.classpath.value = sys.props("java.class.path")
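
`SparkILoop.run` is handy for tests: it feeds a code string to a fresh interpreter loop and returns the captured transcript. A hypothetical usage sketch, assuming `SparkILoop` lives in `org.apache.spark.repl` and a REPL-capable classpath is available at runtime:

```scala
// Hypothetical usage of SparkILoop.run in a test harness.
import org.apache.spark.repl.SparkILoop

object SparkILoopRunExample {
  def main(args: Array[String]): Unit = {
    val transcript: String = SparkILoop.run(
      """
        |val xs = 1 to 3
        |xs.sum
      """.stripMargin)
    // The returned string is the captured REPL transcript, e.g. containing "resN: Int = 6".
    println(transcript)
  }
}
```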
    diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    deleted file mode 100644
    index 1cb910f376060..0000000000000
    --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
    +++ /dev/null
    @@ -1,1319 +0,0 @@
    -/* NSC -- new Scala compiler
    - * Copyright 2005-2013 LAMP/EPFL
    - * @author  Martin Odersky
    - */
    -
    -package scala
    -package tools.nsc
    -package interpreter
    -
    -import PartialFunction.cond
    -import scala.language.implicitConversions
    -import scala.beans.BeanProperty
    -import scala.collection.mutable
    -import scala.concurrent.{ Future, ExecutionContext }
    -import scala.reflect.runtime.{ universe => ru }
    -import scala.reflect.{ ClassTag, classTag }
    -import scala.reflect.internal.util.{ BatchSourceFile, SourceFile }
    -import scala.tools.util.PathResolver
    -import scala.tools.nsc.io.AbstractFile
    -import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings }
    -import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps }
    -import scala.tools.nsc.util.Exceptional.unwrap
    -import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable}
    -
    -/** An interpreter for Scala code.
    -  *
    -  *  The main public entry points are compile(), interpret(), and bind().
    -  *  The compile() method loads a complete Scala file.  The interpret() method
    -  *  executes one line of Scala code at the request of the user.  The bind()
    -  *  method binds an object to a variable that can then be used by later
    -  *  interpreted code.
    -  *
    -  *  The overall approach is based on compiling the requested code and then
    -  *  using a Java classloader and Java reflection to run the code
    -  *  and access its results.
    -  *
    -  *  In more detail, a single compiler instance is used
    -  *  to accumulate all successfully compiled or interpreted Scala code.  To
    -  *  "interpret" a line of code, the compiler generates a fresh object that
    -  *  includes the line of code and which has public member(s) to export
    -  *  all variables defined by that code.  To extract the result of an
    -  *  interpreted line to show the user, a second "result object" is created
    -  *  which imports the variables exported by the above object and then
    -  *  exports members called "$eval" and "$print". To accommodate user expressions
    -  *  that read from variables or methods defined in previous statements, "import"
    -  *  statements are used.
    -  *
    -  *  This interpreter shares the strengths and weaknesses of using the
    -  *  full compiler-to-Java.  The main strength is that interpreted code
    -  *  behaves exactly as does compiled code, including running at full speed.
    -  *  The main weakness is that redefining classes and methods is not handled
    -  *  properly, because rebinding at the Java level is technically difficult.
    -  *
    -  *  @author Moez A. Abdel-Gawad
    -  *  @author Lex Spoon
    -  */
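
The scaladoc above describes the core trick: each interpreted line is compiled into a fresh wrapper object, and a second "result object" imports it and exposes `$eval`/`$print` for display. A much-simplified sketch of that wrapping idea; the generated names and shape here are illustrative, not the interpreter's actual code generation.

```scala
// Rough illustration of the "wrap each line in an object" idea from the scaladoc above.
// The real interpreter generates and compiles code of this general shape; here we only
// build the strings to show the structure.
object LineWrappingSketch {
  def wrap(lineNumber: Int, userCode: String, priorImports: Seq[String]): (String, String) = {
    val readName = s"$$read$lineNumber"
    val evalName = s"$$eval$lineNumber"
    val request =
      s"""object $readName {
         |  ${priorImports.mkString("\n  ")}
         |  $userCode
         |}""".stripMargin
    val resultObject =
      s"""object $evalName {
         |  import $readName._
         |  def $$print: String = "" // would render the newly defined values for display
         |}""".stripMargin
    (request, resultObject)
  }

  def main(args: Array[String]): Unit = {
    val (req, res) = wrap(1, "val x = 10", Seq("import $read0._"))
    println(req)
    println(res)
  }
}
```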
    -class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings,
    -  protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports {
    -  imain =>
    -
    -  setBindings(createBindings, ScriptContext.ENGINE_SCOPE)
    -  object replOutput extends ReplOutput(settings.Yreploutdir) { }
    -
    -  @deprecated("Use replOutput.dir instead", "2.11.0")
    -  def virtualDirectory = replOutput.dir
    -  // Used in a test case.
    -  def showDirectory() = replOutput.show(out)
    -
    -  private[nsc] var printResults               = true      // whether to print result lines
    -  private[nsc] var totalSilence               = false     // whether to print anything
    -  private var _initializeComplete             = false     // compiler is initialized
    -  private var _isInitialized: Future[Boolean] = null      // set up initialization future
    -  private var bindExceptions                  = true      // whether to bind the lastException variable
    -  private var _executionWrapper               = ""        // code to be wrapped around all lines
    -
    -  /** We're going to go to some trouble to initialize the compiler asynchronously.
    -    *  It's critical that nothing call into it until it's been initialized or we will
    -    *  run into unrecoverable issues, but the perceived repl startup time goes
    -    *  through the roof if we wait for it.  So we initialize it with a future and
    -    *  use a lazy val to ensure that any attempt to use the compiler object waits
    -    *  on the future.
    -    */
    -  private var _classLoader: util.AbstractFileClassLoader = null                              // active classloader
    -  private val _compiler: ReplGlobal                 = newCompiler(settings, reporter)   // our private compiler
    -
    -  def compilerClasspath: Seq[java.net.URL] = (
    -    if (isInitializeComplete) global.classPath.asURLs
    -    else new PathResolver(settings).result.asURLs  // the compiler's classpath
    -    )
    -  def settings = initialSettings
    -  // Run the code body with the given boolean settings flipped to true.
    -  def withoutWarnings[T](body: => T): T = beQuietDuring {
    -    val saved = settings.nowarn.value
    -    if (!saved)
    -      settings.nowarn.value = true
    -
    -    try body
    -    finally if (!saved) settings.nowarn.value = false
    -  }
    -
    -  /** construct an interpreter that reports to Console */
    -  def this(settings: Settings, out: JPrintWriter) = this(null, settings, out)
    -  def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true))
    -  def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
    -  def this(factory: ScriptEngineFactory) = this(factory, new Settings())
    -  def this() = this(new Settings())
    -
    -  lazy val formatting: Formatting = new Formatting {
    -    val prompt = Properties.shellPromptString
    -  }
    -  lazy val reporter: SparkReplReporter = new SparkReplReporter(this)
    -
    -  import formatting._
    -  import reporter.{ printMessage, printUntruncatedMessage }
    -
    -  // This exists mostly because using the reporter too early leads to deadlock.
    -  private def echo(msg: String) { Console println msg }
    -  private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
    -  private def _initialize() = {
    -    try {
    -      // if this crashes, REPL will hang its head in shame
    -      val run = new _compiler.Run()
    -      assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
    -      run compileSources _initSources
    -      _initializeComplete = true
    -      true
    -    }
    -    catch AbstractOrMissingHandler()
    -  }
    -  private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
    -  private val logScope = scala.sys.props contains "scala.repl.scope"
    -  private def scopelog(msg: String) = if (logScope) Console.err.println(msg)
    -
    -  // argument is a thunk to execute after init is done
    -  def initialize(postInitSignal: => Unit) {
    -    synchronized {
    -      if (_isInitialized == null) {
    -        _isInitialized =
    -          Future(try _initialize() finally postInitSignal)(ExecutionContext.global)
    -      }
    -    }
    -  }
    -  def initializeSynchronous(): Unit = {
    -    if (!isInitializeComplete) {
    -      _initialize()
    -      assert(global != null, global)
    -    }
    -  }
    -  def isInitializeComplete = _initializeComplete
    -
    -  lazy val global: Global = {
    -    if (!isInitializeComplete) _initialize()
    -    _compiler
    -  }
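
The comment above `_classLoader`/`_compiler` explains the startup strategy: kick off compiler initialization in a `Future` so the prompt appears quickly, and gate every later use behind a lazy val that initializes synchronously if the background work has not finished. A generic sketch of that pattern; `HeavyResource` is a hypothetical stand-in, not the interpreter's real API.

```scala
// Generic sketch of the "initialize in the background, gate access behind a lazy val" pattern.
import scala.concurrent.{ExecutionContext, Future}

class HeavyResource { def use(): String = "ready" }

class LazyAsyncInit {
  // Scala lazy vals are initialized under a lock, so whichever thread touches `ready`
  // first pays the warm-up cost and every other thread simply waits for it.
  lazy val ready: HeavyResource = {
    Thread.sleep(100) // stands in for the expensive compiler warm-up
    new HeavyResource
  }

  // Kick the lazy val from a background thread so the caller sees a prompt quickly.
  def initializeAsync()(implicit ec: ExecutionContext): Future[HeavyResource] = Future(ready)
}

object LazyAsyncInit {
  def main(args: Array[String]): Unit = {
    import ExecutionContext.Implicits.global
    val repl = new LazyAsyncInit
    repl.initializeAsync()        // background warm-up starts immediately
    println(repl.ready.use())     // blocks only if the warm-up has not finished yet
  }
}
```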
    -
    -  import global._
    -  import definitions.{ ObjectClass, termMember, dropNullaryMethod}
    -
    -  lazy val runtimeMirror = ru.runtimeMirror(classLoader)
    -
    -  private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol }
    -
    -  def getClassIfDefined(path: String)  = (
    -    noFatal(runtimeMirror staticClass path)
    -      orElse noFatal(rootMirror staticClass path)
    -    )
    -  def getModuleIfDefined(path: String) = (
    -    noFatal(runtimeMirror staticModule path)
    -      orElse noFatal(rootMirror staticModule path)
    -    )
    -
    -  implicit class ReplTypeOps(tp: Type) {
    -    def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
    -  }
    -
    -  // TODO: If we try to make naming a lazy val, we run into big time
    -  // scalac unhappiness with what look like cycles.  It has not been easy to
    -  // reduce, but name resolution clearly takes different paths.
    -  object naming extends {
    -    val global: imain.global.type = imain.global
    -  } with Naming {
    -    // make sure we don't overwrite their unwisely named res3 etc.
    -    def freshUserTermName(): TermName = {
    -      val name = newTermName(freshUserVarName())
    -      if (replScope containsName name) freshUserTermName()
    -      else name
    -    }
    -    def isInternalTermName(name: Name) = isInternalVarName("" + name)
    -  }
    -  import naming._
    -
    -  object deconstruct extends {
    -    val global: imain.global.type = imain.global
    -  } with StructuredTypeStrings
    -
    -  lazy val memberHandlers = new {
    -    val intp: imain.type = imain
    -  } with SparkMemberHandlers
    -  import memberHandlers._
    -
    -  /** Temporarily be quiet */
    -  def beQuietDuring[T](body: => T): T = {
    -    val saved = printResults
    -    printResults = false
    -    try body
    -    finally printResults = saved
    -  }
    -  def beSilentDuring[T](operation: => T): T = {
    -    val saved = totalSilence
    -    totalSilence = true
    -    try operation
    -    finally totalSilence = saved
    -  }
    -
    -  def quietRun[T](code: String) = beQuietDuring(interpret(code))
    -
    -  /** takes AnyRef because it may be binding a Throwable or an Exceptional */
    -  private def withLastExceptionLock[T](body: => T, alt: => T): T = {
    -    assert(bindExceptions, "withLastExceptionLock called incorrectly.")
    -    bindExceptions = false
    -
    -    try     beQuietDuring(body)
    -    catch   logAndDiscard("withLastExceptionLock", alt)
    -    finally bindExceptions = true
    -  }
    -
    -  def executionWrapper = _executionWrapper
    -  def setExecutionWrapper(code: String) = _executionWrapper = code
    -  def clearExecutionWrapper() = _executionWrapper = ""
    -
    -  /** interpreter settings */
    -  lazy val isettings = new SparkISettings(this)
    -
    -  /** Instantiate a compiler.  Overridable. */
    -  protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = {
    -    settings.outputDirs setSingleOutput replOutput.dir
    -    settings.exposeEmptyPackage.value = true
    -    new Global(settings, reporter) with ReplGlobal { override def toString: String = "<global>" }
    -  }
    -
    -  /** Parent classloader.  Overridable. */
    -  protected def parentClassLoader: ClassLoader =
    -    settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() )
    -
    -  /* A single class loader is used for all commands interpreted by this Interpreter.
    -     It would also be possible to create a new class loader for each command
    -     to interpret.  The advantages of the current approach are:
    -
    -       - Expressions are only evaluated one time.  This is especially
    -         significant for I/O, e.g. "val x = Console.readLine"
    -
    -     The main disadvantage is:
    -
    -       - Objects, classes, and methods cannot be rebound.  Instead, definitions
    -         shadow the old ones, and old code objects refer to the old
    -         definitions.
    -  */
    -  def resetClassLoader() = {
    -    repldbg("Setting new classloader: was " + _classLoader)
    -    _classLoader = null
    -    ensureClassLoader()
    -  }
    -  final def ensureClassLoader() {
    -    if (_classLoader == null)
    -      _classLoader = makeClassLoader()
    -  }
    -  def classLoader: util.AbstractFileClassLoader = {
    -    ensureClassLoader()
    -    _classLoader
    -  }
    -
    -  def backticked(s: String): String = (
    -    (s split '.').toList map {
    -      case "_"                               => "_"
    -      case s if nme.keywords(newTermName(s)) => s"`$s`"
    -      case s                                 => s
    -    } mkString "."
    -    )
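
`backticked` splits a dotted path on "." and backquotes any segment that is a Scala keyword so the path can be re-parsed. An illustrative re-implementation, with a small hard-coded keyword set standing in for the compiler's `nme.keywords`:

```scala
// Illustrative sketch of backticked's behavior using a tiny keyword set in place of nme.keywords.
object BacktickedSketch {
  private val keywords = Set("type", "object", "class", "val", "def", "package")

  def backticked(s: String): String =
    s.split('.').toList.map {
      case "_"                  => "_"
      case seg if keywords(seg) => s"`$seg`"
      case seg                  => seg
    }.mkString(".")

  def main(args: Array[String]): Unit = {
    println(backticked("foo.type.Bar")) // foo.`type`.Bar
    println(backticked("scala.Predef")) // scala.Predef
  }
}
```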
    -  def readRootPath(readPath: String) = getModuleIfDefined(readPath)
    -
    -  abstract class PhaseDependentOps {
    -    def shift[T](op: => T): T
    -
    -    def path(name: => Name): String = shift(path(symbolOfName(name)))
    -    def path(sym: Symbol): String = backticked(shift(sym.fullName))
    -    def sig(sym: Symbol): String  = shift(sym.defString)
    -  }
    -  object typerOp extends PhaseDependentOps {
    -    def shift[T](op: => T): T = exitingTyper(op)
    -  }
    -  object flatOp extends PhaseDependentOps {
    -    def shift[T](op: => T): T = exitingFlatten(op)
    -  }
    -
    -  def originalPath(name: String): String = originalPath(name: TermName)
    -  def originalPath(name: Name): String   = typerOp path name
    -  def originalPath(sym: Symbol): String  = typerOp path sym
    -  def flatPath(sym: Symbol): String      = flatOp shift sym.javaClassName
    -  def translatePath(path: String) = {
    -    val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path)
    -    sym.toOption map flatPath
    -  }
    -  def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath
    -
    -  private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) {
    -    /** Overridden here to try translating a simple name to the generated
    -      *  class name if the original attempt fails.  This method is used by
    -      *  getResourceAsStream as well as findClass.
    -      */
    -    override protected def findAbstractFile(name: String): AbstractFile =
    -      super.findAbstractFile(name) match {
    -        case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull
    -        case file => file
    -      }
    -  }
    -  private def makeClassLoader(): util.AbstractFileClassLoader =
    -    new TranslatingClassLoader(parentClassLoader match {
    -      case null   => ScalaClassLoader fromURLs compilerClasspath
    -      case p      => new ScalaClassLoader.URLClassLoader(compilerClasspath, p)
    -    })
    -
    -  // Set the current Java "context" class loader to this interpreter's class loader
    -  def setContextClassLoader() = classLoader.setAsContext()
    -
    -  def allDefinedNames: List[Name]  = exitingTyper(replScope.toList.map(_.name).sorted)
    -  def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted
    -
    -  /** Most recent tree handled which wasn't wholly synthetic. */
    -  private def mostRecentlyHandledTree: Option[Tree] = {
    -    prevRequests.reverse foreach { req =>
    -      req.handlers.reverse foreach {
    -        case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
    -        case _ => ()
    -      }
    -    }
    -    None
    -  }
    -
    -  private def updateReplScope(sym: Symbol, isDefined: Boolean) {
    -    def log(what: String) {
    -      val mark = if (sym.isType) "t " else "v "
    -      val name = exitingTyper(sym.nameString)
    -      val info = cleanTypeAfterTyper(sym)
    -      val defn = sym defStringSeenAs info
    -
    -      scopelog(f"[$mark$what%6s] $name%-25s $defn%s")
    -    }
    -    if (ObjectClass isSubClass sym.owner) return
    -    // unlink previous
    -    replScope lookupAll sym.name foreach { sym =>
    -      log("unlink")
    -      replScope unlink sym
    -    }
    -    val what = if (isDefined) "define" else "import"
    -    log(what)
    -    replScope enter sym
    -  }
    -
    -  def recordRequest(req: Request) {
    -    if (req == null)
    -      return
    -
    -    prevRequests += req
    -
    -    // warning about serially defining companions.  It'd be easy
    -    // enough to just redefine them together but that may not always
    -    // be what people want so I'm waiting until I can do it better.
    -    exitingTyper {
    -      req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym =>
    -        val oldSym = replScope lookup newSym.name.companionName
    -        if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) {
    -          replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")
    -          replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
    -        }
    -      }
    -    }
    -    exitingTyper {
    -      req.imports foreach (sym => updateReplScope(sym, isDefined = false))
    -      req.defines foreach (sym => updateReplScope(sym, isDefined = true))
    -    }
    -  }
    -
    -  private[nsc] def replwarn(msg: => String) {
    -    if (!settings.nowarnings)
    -      printMessage(msg)
    -  }
    -
    -  def compileSourcesKeepingRun(sources: SourceFile*) = {
    -    val run = new Run()
    -    assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
    -    reporter.reset()
    -    run compileSources sources.toList
    -    (!reporter.hasErrors, run)
    -  }
    -
    -  /** Compile an nsc SourceFile.  Returns true if there are
    -    *  no compilation errors, or false otherwise.
    -    */
    -  def compileSources(sources: SourceFile*): Boolean =
    -    compileSourcesKeepingRun(sources: _*)._1
    -
    -  /** Compile a string.  Returns true if there are no
    -    *  compilation errors, or false otherwise.
    -    */
    -  def compileString(code: String): Boolean =
    -    compileSources(new BatchSourceFile("<script>", code))
    -      
    -      
    +      
    +      
         // scalastyle:on
       }
     
    @@ -186,6 +186,8 @@ private[ui] class StreamingPage(parent: StreamingTab)
           
             {SparkUIUtils.formatDate(startTime)}
           
    +      ({listener.numTotalCompletedBatches}
    +      completed batches, {listener.numTotalReceivedRecords} records)
         

    } @@ -199,9 +201,9 @@ private[ui] class StreamingPage(parent: StreamingTab) * @param times all time values that will be used in the graphs. */ private def generateTimeMap(times: Seq[Long]): Seq[Node] = { - val dateFormat = new SimpleDateFormat("HH:mm:ss") val js = "var timeFormat = {};\n" + times.map { time => - val formattedTime = dateFormat.format(new Date(time)) + val formattedTime = + UIUtils.formatBatchTime(time, listener.batchDuration, showYYYYMMSS = false) s"timeFormat[$time] = '$formattedTime';" }.mkString("\n") @@ -244,17 +246,6 @@ private[ui] class StreamingPage(parent: StreamingTab) val maxEventRate = eventRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) val minEventRate = 0L - // JavaScript to show/hide the InputDStreams sub table. - val triangleJs = - s"""$$('#inputs-table').toggle('collapsed'); - |var status = false; - |if ($$(this).html() == '$BLACK_RIGHT_TRIANGLE_HTML') { - |$$(this).html('$BLACK_DOWN_TRIANGLE_HTML');status = true;} - |else {$$(this).html('$BLACK_RIGHT_TRIANGLE_HTML');status = false;} - |window.history.pushState('', - | document.title, window.location.pathname + '?show-streams-detail=' + status);""" - .stripMargin.replaceAll("\\n", "") // it must be only one single line - val batchInterval = UIUtils.convertToTimeUnit(listener.batchDuration, normalizedUnit) val jsCollector = new JsCollector @@ -319,17 +310,25 @@ private[ui] class StreamingPage(parent: StreamingTab) Timelines (Last {batchTimes.length} batches, {numActiveBatches} active, {numCompletedBatches} completed) - Histograms + Histograms
    - {if (hasStream) { - {Unparsed(BLACK_RIGHT_TRIANGLE_HTML)} - }} - Input Rate + { + if (hasStream) { + + + + Input Rate + + + } else { + Input Rate + } + }
    Avg: {eventRateForAllStreams.formattedAvg} events/sec
    @@ -457,7 +456,7 @@ private[ui] class StreamingPage(parent: StreamingTab) {receiverActive} {receiverLocation} {receiverLastErrorTime} -
    {receiverLastError}
    +
    {receiverLastError}
    @@ -475,14 +474,14 @@ private[ui] class StreamingPage(parent: StreamingTab) val activeBatchesContent = {

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ - new ActiveBatchTable(runningBatches, waitingBatches).toNodeSeq + new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq } val completedBatchesContent = {

    Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches})

    ++ - new CompletedBatchTable(completedBatches).toNodeSeq + new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq } activeBatchesContent ++ completedBatchesContent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index f307b54bb9630..e0c0f57212f55 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.ui +import org.eclipse.jetty.servlet.ServletContextHandler + import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{JettyUtils, SparkUI, SparkUITab} import StreamingTab._ @@ -30,6 +32,8 @@ import StreamingTab._ private[spark] class StreamingTab(val ssc: StreamingContext) extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { + private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static" + val parent = getSparkUI(ssc) val listener = ssc.progressListener @@ -38,12 +42,18 @@ private[spark] class StreamingTab(val ssc: StreamingContext) attachPage(new StreamingPage(this)) attachPage(new BatchPage(this)) + var staticHandler: ServletContextHandler = null + def attach() { getSparkUI(ssc).attachTab(this) + staticHandler = JettyUtils.createStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") + getSparkUI(ssc).attachHandler(staticHandler) } def detach() { getSparkUI(ssc).detachTab(this) + getSparkUI(ssc).detachHandler(staticHandler) + staticHandler = null } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index c206f973b2c66..86cfb1fa47370 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.ui +import java.text.SimpleDateFormat +import java.util.TimeZone import java.util.concurrent.TimeUnit -object UIUtils { +private[streaming] object UIUtils { /** * Return the short string for a `TimeUnit`. @@ -62,7 +64,7 @@ object UIUtils { * Convert `milliseconds` to the specified `unit`. We cannot use `TimeUnit.convert` because it * will discard the fractional part. */ - def convertToTimeUnit(milliseconds: Long, unit: TimeUnit): Double = unit match { + def convertToTimeUnit(milliseconds: Long, unit: TimeUnit): Double = unit match { case TimeUnit.NANOSECONDS => milliseconds * 1000 * 1000 case TimeUnit.MICROSECONDS => milliseconds * 1000 case TimeUnit.MILLISECONDS => milliseconds @@ -71,4 +73,55 @@ object UIUtils { case TimeUnit.HOURS => milliseconds / 1000.0 / 60.0 / 60.0 case TimeUnit.DAYS => milliseconds / 1000.0 / 60.0 / 60.0 / 24.0 } + + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. + private val batchTimeFormat = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + } + + private val batchTimeFormatWithMilliseconds = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS") + } + + /** + * If `batchInterval` is less than 1 second, format `batchTime` with milliseconds. 
Otherwise, + * format `batchTime` without milliseconds. + * + * @param batchTime the batch time to be formatted + * @param batchInterval the batch interval + * @param showYYYYMMSS if showing the `yyyy/MM/dd` part. If it's false, the return value wll be + * only `HH:mm:ss` or `HH:mm:ss.SSS` depending on `batchInterval` + * @param timezone only for test + */ + def formatBatchTime( + batchTime: Long, + batchInterval: Long, + showYYYYMMSS: Boolean = true, + timezone: TimeZone = null): String = { + val oldTimezones = + (batchTimeFormat.get.getTimeZone, batchTimeFormatWithMilliseconds.get.getTimeZone) + if (timezone != null) { + batchTimeFormat.get.setTimeZone(timezone) + batchTimeFormatWithMilliseconds.get.setTimeZone(timezone) + } + try { + val formattedBatchTime = + if (batchInterval < 1000) { + batchTimeFormatWithMilliseconds.get.format(batchTime) + } else { + // If batchInterval >= 1 second, don't show milliseconds + batchTimeFormat.get.format(batchTime) + } + if (showYYYYMMSS) { + formattedBatchTime + } else { + formattedBatchTime.substring(formattedBatchTime.indexOf(' ') + 1) + } + } finally { + if (timezone != null) { + batchTimeFormat.get.setTimeZone(oldTimezones._1) + batchTimeFormatWithMilliseconds.get.setTimeZone(oldTimezones._2) + } + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 87ba4f84a9ceb..fe6328b1ce727 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -200,7 +200,7 @@ private[streaming] class FileBasedWriteAheadLog( /** Initialize the log directory or recover existing logs inside the directory */ private def initializeOrRecover(): Unit = synchronized { val logDirectoryPath = new Path(logDirectory) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 4d968f8bfa7a8..408936653c790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -27,7 +27,7 @@ object RawTextHelper { * Splits lines and counts the words. */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OpenHashMap[String,Long] + val map = new OpenHashMap[String, Long] var i = 0 var j = 0 while (iter.hasNext) { @@ -98,7 +98,7 @@ object RawTextHelper { * before real workload starts. 
*/ def warmUp(sc: SparkContext) { - for(i <- 0 to 1) { + for (i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index ca2f319f174a2..6addb96752038 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -35,7 +35,9 @@ private[streaming] object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { + // scalastyle:off println System.err.println("Usage: RawTextSender ") + // scalastyle:on println System.exit(1) } // Parse the arguments using a pattern match diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index c8eef833eb431..dd32ad5ad811d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -106,7 +106,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: } private[streaming] -object RecurringTimer { +object RecurringTimer extends Logging { def main(args: Array[String]) { var lastRecurTime = 0L @@ -114,7 +114,7 @@ object RecurringTimer { def onRecur(time: Long) { val currentTime = System.currentTimeMillis() - println("" + currentTime + ": " + (currentTime - lastRecurTime)) + logInfo("" + currentTime + ": " + (currentTime - lastRecurTime)) lastRecurTime = currentTime } val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 2e00b980b9e44..a34f23475804a 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -364,6 +364,14 @@ private void testReduceByWindow(boolean withInverse) { @SuppressWarnings("unchecked") @Test public void testQueueStream() { + ssc.stop(); + // Create a new JavaStreamingContext without checkpointing + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + List> expected = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), @@ -1766,29 +1774,10 @@ public JavaStreamingContext call() { Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); - // Function to create JavaStreamingContext using existing JavaSparkContext - // without any output operations (used to detect the new context) - Function creatingFunc2 = - new Function() { - public JavaStreamingContext call(JavaSparkContext context) { - newContextCreated.set(true); - return new JavaStreamingContext(context, Seconds.apply(1)); - } - }; - - JavaSparkContext sc = new JavaSparkContext(conf); - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); - Assert.assertTrue("new context not created", newContextCreated.get()); - ssc.stop(false); - - newContextCreated.set(false); - ssc = 
JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); - Assert.assertTrue("new context not created", newContextCreated.get()); - ssc.stop(false); - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + JavaSparkContext sc = new JavaSparkContext(conf); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 87bc20f79c3cd..08faeaa58f419 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -255,7 +255,7 @@ class BasicOperationsSuite extends TestSuiteBase { Seq( ) ) val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) + s1.map(x => (x, 1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) } testOperation(inputData1, inputData2, operation, outputData, true) } @@ -427,9 +427,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("updateStateByKey - object lifecycle") { val inputData = Seq( - Seq("a","b"), + Seq("a", "b"), null, - Seq("a","c","a"), + Seq("a", "c", "a"), Seq("c"), null, null @@ -557,6 +557,9 @@ class BasicOperationsSuite extends TestSuiteBase { withTestServer(new TestServer()) { testServer => withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => testServer.start() + + val batchCounter = new BatchCounter(ssc) + // Set up the streaming context and input streams val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) @@ -587,7 +590,11 @@ class BasicOperationsSuite extends TestSuiteBase { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(200) + val numCompletedBatches = batchCounter.getNumCompletedBatches clock.advance(batchDuration.milliseconds) + if (!batchCounter.waitUntilBatchesCompleted(numCompletedBatches + 1, 5000)) { + fail("Batch took more than 5 seconds to complete") + } collectRddInfo() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d06..6a94928076236 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -424,11 +424,11 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } + clock.advance(batchDuration.milliseconds) // Wait for a checkpoint to be written eventually(eventuallyTimeout) { assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) @@ -454,9 +454,12 @@ class CheckpointSuite extends TestSuiteBase { // recorded before failure were saved and successfully recovered logInfo("*********** RESTARTING ************") withStreamingContext(new StreamingContext(checkpointDir)) { ssc => - // So that the restarted StreamingContext's 
clock has gone forward in time since failure - ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.getTimeMillis() + // "batchDuration.milliseconds * 3" has gone before restarting StreamingContext. And because + // the recovery time is read from the checkpoint time but the original clock doesn't align + // with the batch time, we need to add the offset "batchDuration.milliseconds / 2". + ssc.conf.set("spark.streaming.manualClock.jump", + (batchDuration.milliseconds / 2 + batchDuration.milliseconds * 3).toString) + val oldClockTime = clock.getTimeMillis() // 15000ms clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -467,10 +470,10 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.getTimeMillis() === oldClockTime + assert(clock.getTimeMillis() === oldClockTime) } - // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) - val numBatchesAfterRestart = 4 + // There are 5 batches between 6000ms and 15000ms (inclusive). + val numBatchesAfterRestart = 5 eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === numBatchesAfterRestart) } @@ -483,7 +486,6 @@ class CheckpointSuite extends TestSuiteBase { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.advance(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala new file mode 100644 index 0000000000000..9b5e4dc819a2b --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.NotSerializableException + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.ReturnStatementInClosureException + +/** + * Test that closures passed to DStream operations are actually cleaned. 
+ */ +class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { + private var ssc: StreamingContext = null + + override def beforeAll(): Unit = { + val sc = new SparkContext("local", "test") + ssc = new StreamingContext(sc, Seconds(1)) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + ssc = null + } + + test("user provided closures are actually cleaned") { + val dstream = new DummyInputDStream(ssc) + val pairDstream = dstream.map { i => (i, i) } + // DStream + testMap(dstream) + testFlatMap(dstream) + testFilter(dstream) + testMapPartitions(dstream) + testReduce(dstream) + testForeach(dstream) + testForeachRDD(dstream) + testTransform(dstream) + testTransformWith(dstream) + testReduceByWindow(dstream) + // PairDStreamFunctions + testReduceByKey(pairDstream) + testCombineByKey(pairDstream) + testReduceByKeyAndWindow(pairDstream) + testUpdateStateByKey(pairDstream) + testMapValues(pairDstream) + testFlatMapValues(pairDstream) + // StreamingContext + testTransform2(ssc, dstream) + } + + /** + * Verify that the expected exception is thrown. + * + * We use return statements as an indication that a closure is actually being cleaned. + * We expect closure cleaner to find the return statements in the user provided closures. + */ + private def expectCorrectException(body: => Unit): Unit = { + try { + body + } catch { + case rse: ReturnStatementInClosureException => // Success! + case e @ (_: NotSerializableException | _: SparkException) => + throw new TestException( + s"Expected ReturnStatementInClosureException, but got $e.\n" + + "This means the closure provided by user is not actually cleaned.") + } + } + + // DStream operations + private def testMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.map { _ => return; 1 } + } + private def testFlatMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.flatMap { _ => return; Seq.empty } + } + private def testFilter(ds: DStream[Int]): Unit = expectCorrectException { + ds.filter { _ => return; true } + } + private def testMapPartitions(ds: DStream[Int]): Unit = expectCorrectException { + ds.mapPartitions { _ => return; Seq.empty.toIterator } + } + private def testReduce(ds: DStream[Int]): Unit = expectCorrectException { + ds.reduce { case (_, _) => return; 1 } + } + private def testForeach(ds: DStream[Int]): Unit = { + val foreachF1 = (rdd: RDD[Int], t: Time) => return + val foreachF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreach(foreachF1) } + expectCorrectException { ds.foreach(foreachF2) } + } + private def testForeachRDD(ds: DStream[Int]): Unit = { + val foreachRDDF1 = (rdd: RDD[Int], t: Time) => return + val foreachRDDF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreachRDD(foreachRDDF1) } + expectCorrectException { ds.foreachRDD(foreachRDDF2) } + } + private def testTransform(ds: DStream[Int]): Unit = { + val transformF1 = (rdd: RDD[Int]) => { return; rdd } + val transformF2 = (rdd: RDD[Int], time: Time) => { return; rdd } + expectCorrectException { ds.transform(transformF1) } + expectCorrectException { ds.transform(transformF2) } + } + private def testTransformWith(ds: DStream[Int]): Unit = { + val transformF1 = (rdd1: RDD[Int], rdd2: RDD[Int]) => { return; rdd1 } + val transformF2 = (rdd1: RDD[Int], rdd2: RDD[Int], time: Time) => { return; rdd2 } + expectCorrectException { ds.transformWith(ds, transformF1) } + expectCorrectException { ds.transformWith(ds, transformF2) } + } + private def testReduceByWindow(ds: DStream[Int]): Unit = { + val reduceF = 
(_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByWindow(reduceF, reduceF, Seconds(1), Seconds(2)) } + } + + // PairDStreamFunctions operations + private def testReduceByKey(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByKey(reduceF) } + expectCorrectException { ds.reduceByKey(reduceF, 5) } + expectCorrectException { ds.reduceByKey(reduceF, new HashPartitioner(5)) } + } + private def testCombineByKey(ds: DStream[(Int, Int)]): Unit = { + expectCorrectException { + ds.combineByKey[Int]( + { _: Int => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + new HashPartitioner(5) + ) + } + } + private def testReduceByKeyAndWindow(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + val filterF = (_: (Int, Int)) => { return; false } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), 5) } + expectCorrectException { + ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), new HashPartitioner(5)) + } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, reduceF, Seconds(2)) } + expectCorrectException { + ds.reduceByKeyAndWindow( + reduceF, reduceF, Seconds(2), Seconds(3), new HashPartitioner(5), filterF) + } + } + private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = { + val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) } + val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator } + val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) } + expectCorrectException { ds.updateStateByKey(updateF1) } + expectCorrectException { ds.updateStateByKey(updateF1, 5) } + expectCorrectException { ds.updateStateByKey(updateF1, new HashPartitioner(5)) } + expectCorrectException { + ds.updateStateByKey(updateF1, new HashPartitioner(5), initialRDD) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD) + } + } + private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.mapValues { _ => return; 1 } + } + private def testFlatMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.flatMapValues { _ => return; Seq.empty } + } + + // StreamingContext operations + private def testTransform2(ssc: StreamingContext, ds: DStream[Int]): Unit = { + val transformF = (rdds: Seq[RDD[_]], time: Time) => { return; ssc.sparkContext.emptyRDD[Int] } + expectCorrectException { ssc.transform(Seq(ds), transformF) } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala new file mode 100644 index 0000000000000..8844c9d74b933 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.ui.UIUtils + +/** + * Tests whether scope information is passed from DStream operations to RDDs correctly. + */ +class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + private var ssc: StreamingContext = null + private val batchDuration: Duration = Seconds(1) + + override def beforeAll(): Unit = { + ssc = new StreamingContext(new SparkContext("local", "test"), batchDuration) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + } + + before { assertPropertiesNotSet() } + after { assertPropertiesNotSet() } + + test("dstream without scope") { + val dummyStream = new DummyDStream(ssc) + dummyStream.initialize(Time(0)) + + // This DStream is not instantiated in any scope, so all RDDs + // created by this stream should similarly not have a scope + assert(dummyStream.baseScope === None) + assert(dummyStream.getOrCompute(Time(1000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(2000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(3000)).get.scope === None) + } + + test("input dstream without scope") { + val inputStream = new DummyInputDStream(ssc) + inputStream.initialize(Time(0)) + + val baseScope = inputStream.baseScope.map(RDDOperationScope.fromJson) + val scope1 = inputStream.getOrCompute(Time(1000)).get.scope + val scope2 = inputStream.getOrCompute(Time(2000)).get.scope + val scope3 = inputStream.getOrCompute(Time(3000)).get.scope + + // This DStream is not instantiated in any scope, so all RDDs + assertDefined(baseScope, scope1, scope2, scope3) + assert(baseScope.get.name.startsWith("dummy stream")) + assertScopeCorrect(baseScope.get, scope1.get, 1000) + assertScopeCorrect(baseScope.get, scope2.get, 2000) + assertScopeCorrect(baseScope.get, scope3.get, 3000) + } + + test("scoping simple operations") { + val inputStream = new DummyInputDStream(ssc) + val mappedStream = inputStream.map { i => i + 1 } + val filteredStream = mappedStream.filter { i => i % 2 == 0 } + filteredStream.initialize(Time(0)) + + val mappedScopeBase = mappedStream.baseScope.map(RDDOperationScope.fromJson) + val mappedScope1 = mappedStream.getOrCompute(Time(1000)).get.scope + val mappedScope2 = mappedStream.getOrCompute(Time(2000)).get.scope + val mappedScope3 = mappedStream.getOrCompute(Time(3000)).get.scope + val filteredScopeBase = filteredStream.baseScope.map(RDDOperationScope.fromJson) + val filteredScope1 = filteredStream.getOrCompute(Time(1000)).get.scope + val filteredScope2 = filteredStream.getOrCompute(Time(2000)).get.scope + val filteredScope3 = filteredStream.getOrCompute(Time(3000)).get.scope + + // These streams are defined in their respective scopes "map" 
and "filter", so all + // RDDs created by these streams should inherit the IDs and names of their parent + // DStream's base scopes + assertDefined(mappedScopeBase, mappedScope1, mappedScope2, mappedScope3) + assertDefined(filteredScopeBase, filteredScope1, filteredScope2, filteredScope3) + assert(mappedScopeBase.get.name === "map") + assert(filteredScopeBase.get.name === "filter") + assertScopeCorrect(mappedScopeBase.get, mappedScope1.get, 1000) + assertScopeCorrect(mappedScopeBase.get, mappedScope2.get, 2000) + assertScopeCorrect(mappedScopeBase.get, mappedScope3.get, 3000) + assertScopeCorrect(filteredScopeBase.get, filteredScope1.get, 1000) + assertScopeCorrect(filteredScopeBase.get, filteredScope2.get, 2000) + assertScopeCorrect(filteredScopeBase.get, filteredScope3.get, 3000) + } + + test("scoping nested operations") { + val inputStream = new DummyInputDStream(ssc) + val countStream = inputStream.countByWindow(Seconds(10), Seconds(1)) + countStream.initialize(Time(0)) + + val countScopeBase = countStream.baseScope.map(RDDOperationScope.fromJson) + val countScope1 = countStream.getOrCompute(Time(1000)).get.scope + val countScope2 = countStream.getOrCompute(Time(2000)).get.scope + val countScope3 = countStream.getOrCompute(Time(3000)).get.scope + + // Assert that all children RDDs inherit the DStream operation name correctly + assertDefined(countScopeBase, countScope1, countScope2, countScope3) + assert(countScopeBase.get.name === "countByWindow") + assertScopeCorrect(countScopeBase.get, countScope1.get, 1000) + assertScopeCorrect(countScopeBase.get, countScope2.get, 2000) + assertScopeCorrect(countScopeBase.get, countScope3.get, 3000) + + // All streams except the input stream should share the same scopes as `countStream` + def testStream(stream: DStream[_]): Unit = { + if (stream != inputStream) { + val myScopeBase = stream.baseScope.map(RDDOperationScope.fromJson) + val myScope1 = stream.getOrCompute(Time(1000)).get.scope + val myScope2 = stream.getOrCompute(Time(2000)).get.scope + val myScope3 = stream.getOrCompute(Time(3000)).get.scope + assertDefined(myScopeBase, myScope1, myScope2, myScope3) + assert(myScopeBase === countScopeBase) + assert(myScope1 === countScope1) + assert(myScope2 === countScope2) + assert(myScope3 === countScope3) + // Climb upwards to test the parent streams + stream.dependencies.foreach(testStream) + } + } + testStream(countStream) + } + + /** Assert that the RDD operation scope properties are not set in our SparkContext. */ + private def assertPropertiesNotSet(): Unit = { + assert(ssc != null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY) == null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY) == null) + } + + /** Assert that the given RDD scope inherits the name and ID of the base scope correctly. */ + private def assertScopeCorrect( + baseScope: RDDOperationScope, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + assertScopeCorrect(baseScope.id, baseScope.name, rddScope, batchTime) + } + + /** Assert that the given RDD scope inherits the base name and ID correctly. 
*/ + private def assertScopeCorrect( + baseScopeId: String, + baseScopeName: String, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + val formattedBatchTime = UIUtils.formatBatchTime( + batchTime, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + assert(rddScope.id === s"${baseScopeId}_$batchTime") + assert(rddScope.name.replaceAll("\\n", " ") === s"$baseScopeName @ $formattedBatchTime") + } + + /** Assert that all the specified options are defined. */ + private def assertDefined[T](options: Option[T]*): Unit = { + options.zipWithIndex.foreach { case (o, i) => assert(o.isDefined, s"Option $i was empty!") } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 93e6b0cd7c661..b74d67c63a788 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -105,6 +106,36 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("socket input stream - no block in a batch") { + withTestServer(new TestServer()) { testServer => + testServer.start() + + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.addStreamingListener(ssc.progressListener) + + val batchCounter = new BatchCounter(ssc) + val networkStream = ssc.socketTextStream( + "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + outputStream.register() + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // Make sure the first batch is finished + if (!batchCounter.waitUntilBatchesCompleted(1, 30000)) { + fail("Timeout: cannot finish all batches in 30 seconds") + } + + networkStream.generatedRDDs.foreach { case (_, rdd) => + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + } + } + } + } + test("binary records stream") { val testDir: File = null try { @@ -387,7 +418,7 @@ class TestServer(portToBind: Int = 0) extends Logging { val servingThread = new Thread() { override def run() { try { - while(true) { + while (true) { logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() if (startLatch.getCount == 1) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index e0f14fd954280..6e9d4431090a2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -43,6 +43,7 @@ object MasterFailureTest extends Logging { @volatile var setupCalled = false def main(args: Array[String]) { + // scalastyle:off println if (args.size < 2) { println( "Usage: MasterFailureTest <# batches> " 
+ @@ -60,6 +61,7 @@ object MasterFailureTest extends Logging { testUpdateStateByKey(directory, numBatches, batchDuration) println("\n\nSUCCESS\n\n") + // scalastyle:on println } def testMap(directory: String, numBatches: Int, batchDuration: Duration) { @@ -291,10 +293,12 @@ object MasterFailureTest extends Logging { } // Log the output + // scalastyle:off println println("Expected output, size = " + expectedOutput.size) println(expectedOutput.mkString("[", ",", "]")) println("Output, size = " + output.size) println(output.mkString("[", ",", "]")) + // scalastyle:on println // Match the output with the expected output output.foreach(o => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 23804237bda80..6c0c926755c20 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -41,11 +41,14 @@ import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ -class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class ReceivedBlockHandlerSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with Logging { val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() - val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -53,10 +56,12 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val serializer = new KryoSerializer(conf) val manualClock = new ManualClock val blockManagerSize = 10000000 + val blockManagerBuffer = new ArrayBuffer[BlockManager]() var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null + var storageLevel: StorageLevel = null var tempDirectory: File = null before { @@ -66,20 +71,21 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, - blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr), securityMgr, 0) - blockManager.initialize("app-id") + storageLevel = StorageLevel.MEMORY_ONLY_SER + blockManager = createBlockManager(blockManagerSize, conf) tempDirectory = Utils.createTempDir() manualClock.setTime(0) } after { - if (blockManager != null) { - blockManager.stop() - blockManager = null + for ( blockManager <- blockManagerBuffer ) { + if (blockManager != null) { + blockManager.stop() + } } + blockManager = null + blockManagerBuffer.clear() if (blockManagerMaster != null) { blockManagerMaster.stop() blockManagerMaster = null @@ -170,6 +176,130 @@ class ReceivedBlockHandlerSuite extends FunSuite with 
BeforeAndAfter with Matche } } + test("Test Block - count messages") { + // Test count with BlockManagerBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(true) + // Test count with WriteAheadLogBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(false) + } + + test("Test Block - isFullyConsumed") { + val sparkConf = new SparkConf() + sparkConf.set("spark.storage.unrollMemoryThreshold", "512") + // spark.storage.unrollFraction set to 0.4 for BlockManager + sparkConf.set("spark.storage.unrollFraction", "0.4") + // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll + blockManager = createBlockManager(12000, sparkConf) + + // There is not enough space to store this block in MEMORY, + // but BlockManager will be able to serialize this block to the WAL + // and hence count returns the correct value. + testRecordcount(false, StorageLevel.MEMORY_ONLY, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // There is not enough space to store this block in MEMORY, + // but BlockManager will be able to serialize this block to DISK + // and hence count returns the correct value. + testRecordcount(true, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // There is not enough space to store this block with the MEMORY_ONLY StorageLevel. + // BlockManager will not be able to unroll this block + // and hence it will not tryToPut this block, resulting in a SparkException. + storageLevel = StorageLevel.MEMORY_ONLY + withBlockManagerBasedBlockHandler { handler => + val thrown = intercept[SparkException] { + storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator)) + } + } + } + + private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) { + // ByteBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ByteBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ArrayBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50)) + // ArrayBufferBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75)) + // IteratorBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, 
Some(125)) + // IteratorBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150)) + } + + private def createBlockManager( + maxMem: Long, + conf: SparkConf, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + manager.initialize("app-id") + blockManagerBuffer += manager + manager + } + + /** + * Test storing of data using different types of handler, StorageLevel and ReceivedBlock, + * and verify the correct record count + */ + private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, + sLevel: StorageLevel, + receivedBlock: ReceivedBlock, + bManager: BlockManager, + expectedNumRecords: Option[Long] + ) { + blockManager = bManager + storageLevel = sLevel + var bId: StreamBlockId = null + try { + if (isBlockManagedBasedBlockHandler) { + // test received block with BlockManager based handler + withBlockManagerBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count does not match for a " + + receivedBlock.getClass.getName + + " being inserted using BlockManagerBasedBlockHandler with " + sLevel) + } + } else { + // test received block with WAL based handler + withWriteAheadLogBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count does not match for a " + + receivedBlock.getClass.getName + + " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel) + } + } + } finally { + // Remove the block so that the same blockManager can be reused by the next test + blockManager.removeBlock(bId, true) + } + } + /** * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded * using the given verification function @@ -247,9 +377,21 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche (blockIds, storeResults) } + /** Store single block using a handler */ + private def storeSingleBlock( + handler: ReceivedBlockHandler, + block: ReceivedBlock + ): (StreamBlockId, ReceivedBlockStoreResult) = { + val blockId = generateBlockId + val blockStoreResult = handler.storeBlock(blockId, block) + logDebug("Done inserting") + (blockId, blockStoreResult) + } + private def getWriteAheadLogFiles(): Seq[String] = { getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) } private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) } + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b1af8d5eaacfb..f793a12843b2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import 
org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ @@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite - extends FunSuite with BeforeAndAfter with Matchers with Logging { + extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() val akkaTimeout = 10 seconds @@ -224,8 +224,8 @@ class ReceivedBlockTrackerSuite /** Generate blocks infos using random ids */ def generateBlockInfos(): Seq[ReceivedBlockInfo] = { - List.fill(5)(ReceivedBlockInfo(streamId, 0, None, - BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } /** Get all the data written in the given write ahead log file. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5f93332896de1..f588cf5bc1e7c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,24 +17,29 @@ package org.apache.spark.streaming -import java.io.File +import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Queue + import org.apache.commons.io.FileUtils -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} -import org.scalatest.concurrent.Timeouts +import org.scalatest.{Assertions, BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName @@ -110,6 +115,15 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } + test("checkPoint from conf") { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) + val ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.checkpointDir != null) + } + 
test("state matching") { import StreamingContextState._ assert(INITIALIZED === INITIALIZED) @@ -132,6 +146,41 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("start with non-seriazable DStream checkpoints") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir.getAbsolutePath) + addInputStream(ssc).foreachRDD { rdd => + // Refer to this.appName from inside closure so that this closure refers to + // the instance of StreamingContextSuite, and is therefore not serializable + rdd.count() + appName + } + + // Test whether start() fails early when checkpointing is enabled + val exception = intercept[NotSerializableException] { + ssc.start() + } + assert(exception.getMessage().contains("DStreams with their functions are not serializable")) + assert(ssc.getState() !== StreamingContextState.ACTIVE) + assert(StreamingContext.getActive().isEmpty) + } + + test("start failure should stop internal components") { + ssc = new StreamingContext(conf, batchDuration) + val inputStream = addInputStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + // Require that the start fails because checkpoint directory was not set + intercept[Exception] { + ssc.start() + } + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(ssc.scheduler.isStarted === false) + } + + test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() @@ -163,7 +212,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() - intercept[SparkException] { + intercept[IllegalStateException] { ssc.start() // start after stop should throw exception } assert(ssc.getState() === StreamingContextState.STOPPED) @@ -262,6 +311,25 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w Thread.sleep(100) } + test ("registering and de-registering of streamingSource") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + ssc = new StreamingContext(conf, batchDuration) + assert(ssc.getState() === StreamingContextState.INITIALIZED) + addInputStream(ssc).register() + ssc.start() + + val sources = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSource = StreamingContextSuite.getStreamingSource(ssc) + assert(sources.contains(streamingSource)) + assert(ssc.getState() === StreamingContextState.ACTIVE) + + ssc.stop() + val sourcesAfterStop = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSourceAfterStop = StreamingContextSuite.getStreamingSource(ssc) + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(!sourcesAfterStop.contains(streamingSourceAfterStop)) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -419,76 +487,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) assert(ssc != null, "no context created") assert(!newContextCreated, "old context not recovered") - assert(ssc.conf.get("someKey") === "someValue") - } - } - - test("getOrCreate with existing SparkContext") { - val conf = new 
SparkConf().setMaster(master).setAppName(appName) - sc = new SparkContext(conf) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(sparkContext: SparkContext): StreamingContext = { - newContextCreated = true - new StreamingContext(sparkContext, batchDuration) - } - - // Call ssc.stop(stopSparkContext = false) after a body of cody - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - } - ssc = null - } + assert(ssc.conf.get("someKey") === "someValue", "checkpointed config not recovered") } - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path + // getOrCreate should recover StreamingContext with existing SparkContext testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val checkpointPath = createValidCheckpoint() - - // StreamingContext.getOrCreate should recover context with checkpoint path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + sc = new SparkContext(conf) + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) assert(ssc != null, "no context created") assert(!newContextCreated, "old context not recovered") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - assert(!ssc.conf.contains("someKey"), - "recovered StreamingContext unexpectedly has old config") + assert(!ssc.conf.contains("someKey"), "checkpointed config unexpectedly recovered") } } @@ -641,7 +649,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val anotherInput = addInputStream(anotherSsc) anotherInput.foreachRDD { rdd => rdd.count } - val exception = intercept[SparkException] { + val exception = intercept[IllegalStateException] { anotherSsc.start() } assert(exception.getMessage.contains("StreamingContext"), "Did not get the right exception") @@ -664,7 +672,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w def testForException(clue: String, expectedErrorMsg: String)(body: => Unit): Unit = { withClue(clue) { - val ex = intercept[SparkException] { + val ex = intercept[IllegalStateException] { body } 
assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) @@ -690,6 +698,19 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w transformed.foreachRDD { rdd => rdd.collect() } } } + test("queueStream doesn't support checkpointing") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(master, appName, batchDuration) + val rdd = ssc.sparkContext.parallelize(1 to 10) + ssc.queueStream[Int](Queue(rdd)).print() + ssc.checkpoint(checkpointDir.getAbsolutePath) + val e = intercept[NotSerializableException] { + ssc.start() + } + // StreamingContext.validate changes the message, so use "contains" here + assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) @@ -773,7 +794,9 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) def onStop() { // Simulate slow receiver by waiting for all records to be produced - while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + while (!SlowTestReceiver.receivedAllRecords) { + Thread.sleep(100) + } // no clean to be done, the receiving thread should stop on it own } } @@ -819,3 +842,18 @@ package object testPackage extends Assertions { } } } + +/** + * Helper methods for testing StreamingContextSuite + * This includes methods to access private methods and fields in StreamingContext and MetricsSystem + */ +private object StreamingContextSuite extends PrivateMethodTester { + private val _sources = PrivateMethod[ArrayBuffer[Source]]('sources) + private def getSources(metricsSystem: MetricsSystem): ArrayBuffer[Source] = { + metricsSystem.invokePrivate(_sources()) + } + private val _streamingSource = PrivateMethod[StreamingSource]('streamingSource) + private def getStreamingSource(streamingContext: StreamingContext): StreamingSource = { + streamingContext.invokePrivate(_streamingSource()) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 312cce408cfe7..4bc1dd4a30fc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -59,7 +59,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosSubmitted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) @@ -77,7 +77,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosStarted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) @@ -98,7 +98,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosCompleted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) @@ -116,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with 
Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 @@ -133,8 +133,10 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for(i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) return false + for (i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) { + return false + } } true } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 4f70ae7f1f187..0d58a7b54412f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,17 +24,35 @@ import scala.collection.mutable.SynchronizedBuffer import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} +/** + * A dummy stream that does absolutely nothing. + */ +private[streaming] class DummyDStream(ssc: StreamingContext) extends DStream[Int](ssc) { + override def dependencies: List[DStream[Int]] = List.empty + override def slideDuration: Duration = Seconds(1) + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} + +/** + * A dummy input stream that does absolutely nothing. + */ +private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputDStream[Int](ssc) { + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} + /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, * replayable, reliable message queue like Kafka. It requires a sequence as input, and @@ -58,7 +76,7 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], } // Report the input data's information to InputInfoTracker for testing - val inputInfo = InputInfo(id, selectedInput.length.toLong) + val inputInfo = StreamInputInfo(id, selectedInput.length.toLong) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) @@ -186,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) { * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. 
*/ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context def framework: String = this.getClass.getSimpleName diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 441bbf95d0153..a08578680cff9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -27,20 +27,20 @@ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ import org.apache.spark._ - - - +import org.apache.spark.ui.SparkUICssErrorHandler /** - * Selenium tests for the Spark Web UI. + * Selenium tests for the Spark Streaming Web UI. */ class UISeleniumSuite - extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ override def beforeAll(): Unit = { - webDriver = new HtmlUnitDriver + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } } override def afterAll(): Unit = { @@ -197,4 +197,3 @@ class UISeleniumSuite } } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 6859b65c7165f..cb017b798b2a4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,15 +21,15 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite - extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 5478b41845943..f5248acf712b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.streaming.{Time, Duration, StreamingContext} -class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { +class InputInfoTrackerSuite extends SparkFunSuite with 
BeforeAndAfter { private var ssc: StreamingContext = _ @@ -46,8 +46,8 @@ class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { val streamId1 = 0 val streamId2 = 1 val time = Time(0L) - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId2, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId2, 300L) inputInfoTracker.reportInfo(time, inputInfo1) inputInfoTracker.reportInfo(time, inputInfo2) @@ -63,8 +63,8 @@ class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { val inputInfoTracker = new InputInfoTracker(ssc) val streamId1 = 0 - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId1, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId1, 300L) inputInfoTracker.reportInfo(Time(0), inputInfo1) inputInfoTracker.reportInfo(Time(1), inputInfo2) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 7865b06c2e3c2..a2dbae149f311 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -76,7 +76,6 @@ class JobGeneratorSuite extends TestSuiteBase { if (time.milliseconds == longBatchTime) { while (waitLatch.getCount() > 0) { waitLatch.await() - println("Await over") } } }) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala new file mode 100644 index 0000000000000..a6e783861dbe6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.Utils + +/** Testsuite for receiver scheduling */ +class ReceiverTrackerSuite extends TestSuiteBase { + val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") + val ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val tracker = new ReceiverTracker(ssc) + val launcher = new tracker.ReceiverLauncher() + val executors: List[String] = List("0", "1", "2", "3") + + test("receiver scheduling - all or none have preferred location") { + + def parse(s: String): Array[Array[String]] = { + val outerSplit = s.split("\\|") + val loc = new Array[Array[String]](outerSplit.length) + var i = 0 + for (i <- 0 until outerSplit.length) { + loc(i) = outerSplit(i).split("\\,") + } + loc + } + + def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { + val receivers = + if (preferredLocation) { + Array.tabulate(numReceivers)(i => new DummyReceiver(host = + Some(((i + 1) % executors.length).toString))) + } else { + Array.tabulate(numReceivers)(_ => new DummyReceiver) + } + val locations = launcher.scheduleReceivers(receivers, executors) + val expectedLocations = parse(allocation) + assert(locations.deep === expectedLocations.deep) + } + + testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") + testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") + testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") + } + + test("receiver scheduling - some have preferred location") { + val numReceivers = 4; + val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), + new DummyReceiver, new DummyReceiver, new DummyReceiver) + val locations = launcher.scheduleReceivers(receivers, executors) + assert(locations(0)(0) === "1") + assert(locations(1)(0) === "0") + assert(locations(2)(0) === "1") + assert(locations(0).length === 1) + assert(locations(3).length === 1) + } +} + +/** + * Dummy receiver implementation + */ +private class DummyReceiver(host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + def onStart() { + } + + def onStop() { + } + + override def preferredLocation: Option[String] = host +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 2a0f45830e03c..40dc1fb601bd0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -49,10 +49,12 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) 
listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -64,7 +66,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -94,7 +96,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoStarted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoStarted.processingDelay) batchUIData.get.totalDelay should be (batchInfoStarted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map(0 -> 300L, 1 -> 300L)) + batchUIData.get.streamIdToInputInfo should be (Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test")))) batchUIData.get.numRecords should be(600) batchUIData.get.outputOpIdSparkJobIdPairs should be Seq(OutputOpIdAndSparkJobId(0, 0), @@ -103,7 +107,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -141,9 +145,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -182,7 +186,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoSubmitted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoSubmitted.processingDelay) batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map.empty) + batchUIData.get.streamIdToInputInfo should be (Map.empty) batchUIData.get.numRecords should be (0) batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) @@ -211,14 +215,14 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) for (_ <- 0 until 2 * limit) { - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) // onBatchSubmitted - val 
batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -235,7 +239,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala index 6df1a63ab2e37..d3ca2b58f36c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.streaming.ui +import java.util.TimeZone import java.util.concurrent.TimeUnit -import org.scalatest.FunSuite import org.scalatest.Matchers -class UIUtilsSuite extends FunSuite with Matchers{ +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite with Matchers{ test("shortTimeUnitString") { assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) @@ -64,4 +66,14 @@ class UIUtilsSuite extends FunSuite with Matchers{ val convertedTime = UIUtils.convertToTimeUnit(milliseconds, unit) convertedTime should be (expectedTime +- 1E-6) } + + test("formatBatchTime") { + val tzForTest = TimeZone.getTimeZone("America/Los_Angeles") + val batchTime = 1431637480452L // Thu May 14 14:04:40 PDT 2015 + assert("2015/05/14 14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, timezone = tzForTest)) + assert("2015/05/14 14:04:40.452" === + UIUtils.formatBatchTime(batchTime, 999, timezone = tzForTest)) + assert("14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, false, timezone = tzForTest)) + assert("14:04:40.452" === UIUtils.formatBatchTime(batchTime, 999, false, timezone = tzForTest)) + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 9ebf7b484f421..78fc344b00177 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class RateLimitedOutputStreamSuite extends FunSuite { +class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 79098bcf4861c..325ff7c74c39d 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -28,15 +28,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { +class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null @@ -359,7 +359,7 @@ object WriteAheadLogSuite { ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => manualClock.advance(500) diff --git a/tools/pom.xml b/tools/pom.xml index 1c6f3e83a1819..feffde4c857eb 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 595ded6ae67fa..9418beb6b3e3a 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off classforname package org.apache.spark.tools import java.io.File @@ -92,7 +93,9 @@ object GenerateMIMAIgnore { ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { + // scalastyle:off println case _: Throwable => println("Error instrumenting class:" + className) + // scalastyle:on println } } (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) @@ -108,7 +111,9 @@ object GenerateMIMAIgnore { .filter(_.contains("$$")).map(classSymbol.fullName + "." 
+ _) } catch { case t: Throwable => + // scalastyle:off println println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + // scalastyle:on println Seq.empty[String] } } @@ -128,12 +133,14 @@ object GenerateMIMAIgnore { getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-class-excludes") .writeAll(previousContents + privateClasses.mkString("\n")) + // scalastyle:off println println("Created : .generated-mima-class-excludes in current directory.") val previousMembersContents = Try(File(".generated-mima-member-excludes").lines) .getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-member-excludes").writeAll(previousMembersContents + privateMembers.mkString("\n")) println("Created : .generated-mima-member-excludes in current directory.") + // scalastyle:on println } @@ -174,9 +181,12 @@ object GenerateMIMAIgnore { try { classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) } catch { + // scalastyle:off println case _: Throwable => println("Unable to load:" + entry) + // scalastyle:on println } } classes } } +// scalastyle:on classforname diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 583823c90c5c6..856ea177a9a10 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -323,11 +323,14 @@ object JavaAPICompletenessChecker { val missingMethods = javaEquivalents -- javaMethods for (method <- missingMethods) { + // scalastyle:off println println(method) + // scalastyle:on println } } def main(args: Array[String]) { + // scalastyle:off println println("Missing RDD methods") printMissingMethods(classOf[RDD[_]], classOf[JavaRDD[_]]) println() @@ -359,5 +362,6 @@ object JavaAPICompletenessChecker { println("Missing PairDStream methods") printMissingMethods(classOf[PairDStreamFunctions[_, _]], classOf[JavaPairDStream[_, _]]) println() + // scalastyle:on println } } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index baa97616eaff3..0dc2861253f17 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -85,7 +85,9 @@ object StoragePerfTester { latch.countDown() } catch { case e: Exception => + // scalastyle:off println println("Exception in child thread: " + e + " " + e.getMessage) + // scalastyle:on println System.exit(1) } } @@ -97,9 +99,11 @@ object StoragePerfTester { val bytesPerSecond = totalBytes.get() / time val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + // scalastyle:off println System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + // scalastyle:on println executor.shutdown() sc.stop() diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 5b0733206b2bc..33782c6c66f90 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -42,6 +42,10 @@ com.google.code.findbugs jsr305 + + com.google.guava + guava + @@ -61,6 +65,11 @@ junit-interface test 
+ + org.mockito + mockito-core + test + target/scala-${scala.binary.version}/classes @@ -71,7 +80,7 @@ net.alchim31.maven scala-maven-plugin - + -XDignore.symbol.file diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index 24b2892098059..192c6714b2406 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -25,8 +25,7 @@ public final class PlatformDependent { /** * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us aovid accidental use of deprecated methods or methods that - * aren't present in Java 6. + * this package. This also lets us avoid accidental use of deprecated methods. */ public static final class UNSAFE { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 53eadf96a6b52..cf693d01a4f5b 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import static org.apache.spark.unsafe.PlatformDependent.*; public class ByteArrayMethods { @@ -35,21 +35,27 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } /** - * Optimized byte array equality check for 8-byte-word-aligned byte arrays. + * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise */ - public static boolean wordAlignedArrayEquals( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset, - long arrayLengthInBytes) { - for (int i = 0; i < arrayLengthInBytes; i += 8) { - final long left = - PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i); - final long right = - PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i); - if (left != right) return false; + public static boolean arrayEquals( + Object leftBase, + long leftOffset, + Object rightBase, + long rightOffset, + final long length) { + int i = 0; + while (i <= length - 8) { + if (UNSAFE.getLong(leftBase, leftOffset + i) != UNSAFE.getLong(rightBase, rightOffset + i)) { + return false; + } + i += 8; + } + while (i < length) { + if (UNSAFE.getByte(leftBase, leftOffset + i) != UNSAFE.getByte(rightBase, rightOffset + i)) { + return false; + } + i += 1; } return true; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java index 28e23da108ebe..7c124173b0bbb 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -90,7 +90,7 @@ public boolean isSet(int index) { * To iterate over the true bits in a BitSet, use the following loop: *
        * 
    -   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
    +   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
        *    // operate on index i here
        *  }
        * 
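[Editorial aside, not part of the patch] The loop documented in the javadoc above can be exercised against the BitSet(MemoryBlock) constructor and the set(int)/nextSetBit(int) calls that appear later in this diff (BytesToBytesMap uses exactly these). A minimal sketch; the class name and the chosen bit indices are illustrative only, and the int return type follows the callers shown below rather than the long used in the javadoc:

    import org.apache.spark.unsafe.bitset.BitSet;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class BitSetLoopSketch {
      public static void main(String[] args) {
        // Four words back a 256-bit set, mirroring how BytesToBytesMap.allocate() sizes its bit set.
        BitSet bs = new BitSet(MemoryBlock.fromLongArray(new long[4]));
        bs.set(3);
        bs.set(200);
        // nextSetBit returns -1 once no set bit remains, which terminates the loop.
        for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
          System.out.println("bit " + i + " is set");
        }
      }
    }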
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    index 0987191c1c636..27462c7fa5e62 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    @@ -87,7 +87,7 @@ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidt
        * To iterate over the true bits in a BitSet, use the following loop:
        * 
        * 
    -   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
    +   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
        *    // operate on index i here
        *  }
        * 
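[Editorial aside, not part of the patch] Before the BytesToBytesMap changes below: the ByteArrayMethods.arrayEquals rewrite earlier in this diff compares eight bytes at a time and then falls back to single bytes for the tail, so callers such as BytesToBytesMap.lookup() and the new UTF8String no longer need 8-byte-aligned lengths. A minimal sketch of calling it directly on heap byte arrays, reusing the BYTE_ARRAY_OFFSET constant that the tests below also use (the wrapper class name is illustrative):

    import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;

    import org.apache.spark.unsafe.array.ByteArrayMethods;

    public class ArrayEqualsSketch {
      public static void main(String[] args) throws Exception {
        byte[] left = "spark".getBytes("UTF-8");   // 5 bytes: no longer has to be padded to a word
        byte[] right = "spark".getBytes("UTF-8");
        boolean equal = ByteArrayMethods.arrayEquals(
            left, BYTE_ARRAY_OFFSET, right, BYTE_ARRAY_OFFSET, left.length);
        System.out.println(equal);  // prints true
      }
    }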
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
    index 19d6a169fd2ad..d0bde69cc1068 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
    @@ -23,6 +23,8 @@
     import java.util.LinkedList;
     import java.util.List;
     
    +import com.google.common.annotations.VisibleForTesting;
    +
     import org.apache.spark.unsafe.*;
     import org.apache.spark.unsafe.array.ByteArrayMethods;
     import org.apache.spark.unsafe.array.LongArray;
    @@ -36,9 +38,8 @@
      * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
      * which is guaranteed to exhaust the space.
      * 

    - * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is - * higher than this, you should probably be using sorting instead of hashing for better cache - * locality. + * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should + * probably be using sorting instead of hashing for better cache locality. *

    * This class is not thread safe. */ @@ -48,6 +49,11 @@ public final class BytesToBytesMap { private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + /** + * Special record length that is placed after the last record in a data page. + */ + private static final int END_OF_PAGE_MARKER = -1; + private final TaskMemoryManager memoryManager; /** @@ -64,7 +70,7 @@ public final class BytesToBytesMap { /** * Offset into `currentDataPage` that points to the location where new data can be inserted into - * the page. + * the page. This does not incorporate the page's base offset. */ private long pageCursor = 0; @@ -74,6 +80,15 @@ public final class BytesToBytesMap { */ private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + /** + * The maximum number of keys that BytesToBytesMap supports. The hash table has to be + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since + * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). + */ + @VisibleForTesting + static final int MAX_CAPACITY = (1 << 29); + // This choice of page table size and page size means that we can address up to 500 gigabytes // of memory. @@ -143,6 +158,13 @@ public BytesToBytesMap( this.loadFactor = loadFactor; this.loc = new Location(); this.enablePerfMetrics = enablePerfMetrics; + if (initialCapacity <= 0) { + throw new IllegalArgumentException("Initial capacity must be greater than 0"); + } + if (initialCapacity > MAX_CAPACITY) { + throw new IllegalArgumentException( + "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); + } allocate(initialCapacity); } @@ -162,6 +184,55 @@ public BytesToBytesMap( */ public int size() { return size; } + private static final class BytesToBytesMapIterator implements Iterator { + + private final int numRecords; + private final Iterator dataPagesIterator; + private final Location loc; + + private int currentRecordNumber = 0; + private Object pageBaseObject; + private long offsetInPage; + + BytesToBytesMapIterator(int numRecords, Iterator dataPagesIterator, Location loc) { + this.numRecords = numRecords; + this.dataPagesIterator = dataPagesIterator; + this.loc = loc; + if (dataPagesIterator.hasNext()) { + advanceToNextPage(); + } + } + + private void advanceToNextPage() { + final MemoryBlock currentPage = dataPagesIterator.next(); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + } + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public Location next() { + int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + if (keyLength == END_OF_PAGE_MARKER) { + advanceToNextPage(); + keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + } + loc.with(pageBaseObject, offsetInPage); + offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + currentRecordNumber++; + return loc; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + /** * Returns an iterator for iterating over the entries of this map. * @@ -171,27 +242,7 @@ public BytesToBytesMap( * `lookup()`, the behavior of the returned iterator is undefined. 
*/ public Iterator iterator() { - return new Iterator() { - - private int nextPos = bitset.nextSetBit(0); - - @Override - public boolean hasNext() { - return nextPos != -1; - } - - @Override - public Location next() { - final int pos = nextPos; - nextPos = bitset.nextSetBit(nextPos + 1); - return loc.with(pos, 0, true); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; + return new BytesToBytesMapIterator(size, dataPages.iterator(), loc); } /** @@ -226,7 +277,7 @@ public Location lookup( final MemoryLocation keyAddress = loc.getKeyAddress(); final Object storedKeyBaseObject = keyAddress.getBaseObject(); final long storedKeyBaseOffset = keyAddress.getBaseOffset(); - final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals( + final boolean areEqual = ByteArrayMethods.arrayEquals( keyBaseObject, keyBaseOffset, storedKeyBaseObject, @@ -268,8 +319,11 @@ public final class Location { private int valueLength; private void updateAddressesAndSizes(long fullKeyAddress) { - final Object page = memoryManager.getPage(fullKeyAddress); - final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress); + updateAddressesAndSizes( + memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress)); + } + + private void updateAddressesAndSizes(Object page, long keyOffsetInPage) { long position = keyOffsetInPage; keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); position += 8; // word used to store the key size @@ -291,6 +345,12 @@ Location with(int pos, int keyHashcode, boolean isDefined) { return this; } + Location with(Object page, long keyOffsetInPage) { + this.isDefined = true; + updateAddressesAndSizes(page, keyOffsetInPage); + return this; + } + /** * Returns true if the key is defined at this position, and false otherwise. */ @@ -344,12 +404,17 @@ public int getValueLength() { * at the value address. *

    * It is only valid to call this method immediately after calling `lookup()` using the same key. + *

    + *

+ * The key and value must be word-aligned (that is, their sizes must be multiples of 8). + *

    *

    * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` * will return information on the data stored by this `putNewKey` call. + *

    *

    * As an example usage, here's the proper way to store a new key: - *

    + *

    *
          *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
          *   if (!loc.isDefined()) {
    @@ -358,6 +423,7 @@ public int getValueLength() {
          * 
    *

    * Unspecified behavior if the key is not defined. + *

    */ public void putNewKey( Object keyBaseObject, @@ -367,20 +433,29 @@ public void putNewKey( long valueBaseOffset, int valueLengthBytes) { assert (!isDefined) : "Can only set value once for a key"; - isDefined = true; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); + if (size == MAX_CAPACITY) { + throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); + } // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert(requiredSize <= PAGE_SIZE_BYTES); + assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. size++; bitset.set(pos); - // If there's not enough space in the current page, allocate a new page: - if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) { + // If there's not enough space in the current page, allocate a new page (8 bytes are reserved + // for the end-of-page marker). + if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) { + if (currentDataPage != null) { + // There wasn't enough space in the current page, so write an end-of-page marker: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; + PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + } MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); dataPages.add(newPage); pageCursor = 0; @@ -414,7 +489,7 @@ public void putNewKey( longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); isDefined = true; - if (size > growthThreshold) { + if (size > growthThreshold && longArray.size() < MAX_CAPACITY) { growAndRehash(); } } @@ -427,8 +502,11 @@ public void putNewKey( * @param capacity the new map capacity */ private void allocate(int capacity) { - capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); - longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2)); + assert (capacity >= 0); + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); + assert (capacity <= MAX_CAPACITY); + longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2)); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); this.growthThreshold = (int) (capacity * loadFactor); @@ -494,10 +572,16 @@ public long getNumHashCollisions() { return numHashCollisions; } + @VisibleForTesting + int getNumDataPages() { + return dataPages.size(); + } + /** * Grows the size of the hash table and re-hash everything. 
*/ - private void growAndRehash() { + @VisibleForTesting + void growAndRehash() { long resizeStartTime = -1; if (enablePerfMetrics) { resizeStartTime = System.nanoTime(); @@ -508,7 +592,7 @@ private void growAndRehash() { final int oldCapacity = (int) oldBitSet.capacity(); // Allocate the new data structures - allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity))); + allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index 7c321baffe82d..20654e4eeaa02 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -32,7 +32,9 @@ public interface HashMapGrowthStrategy { class Doubling implements HashMapGrowthStrategy { @Override public int nextCapacity(int currentCapacity) { - return currentCapacity * 2; + assert (currentCapacity > 0); + // Guard against overflow + return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE; } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java index 62c29c8cc1e4d..cbbe8594627a5 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -17,6 +17,12 @@ package org.apache.spark.unsafe.memory; +import java.lang.ref.WeakReference; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import javax.annotation.concurrent.GuardedBy; + /** * Manages memory for an executor. Individual operators / tasks allocate memory through * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. @@ -33,6 +39,12 @@ public class ExecutorMemoryManager { */ final boolean inHeap; + @GuardedBy("this") + private final Map>> bufferPoolsBySize = + new HashMap>>(); + + private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; + /** * Construct a new ExecutorMemoryManager. * @@ -43,16 +55,57 @@ public ExecutorMemoryManager(MemoryAllocator allocator) { this.allocator = allocator; } + /** + * Returns true if allocations of the given size should go through the pooling mechanism and + * false otherwise. + */ + private boolean shouldPool(long size) { + // Very small allocations are less likely to benefit from pooling. + // At some point, we should explore supporting pooling for off-heap memory, but for now we'll + // ignore that case in the interest of simplicity. + return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator; + } + /** * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed * to be zeroed out (call `zero()` on the result if this is necessary). 
*/ MemoryBlock allocate(long size) throws OutOfMemoryError { - return allocator.allocate(size); + if (shouldPool(size)) { + synchronized (this) { + final LinkedList> pool = bufferPoolsBySize.get(size); + if (pool != null) { + while (!pool.isEmpty()) { + final WeakReference blockReference = pool.pop(); + final MemoryBlock memory = blockReference.get(); + if (memory != null) { + assert (memory.size() == size); + return memory; + } + } + bufferPoolsBySize.remove(size); + } + } + return allocator.allocate(size); + } else { + return allocator.allocate(size); + } } void free(MemoryBlock memory) { - allocator.free(memory); + final long size = memory.size(); + if (shouldPool(size)) { + synchronized (this) { + LinkedList> pool = bufferPoolsBySize.get(size); + if (pool == null) { + pool = new LinkedList>(); + bufferPoolsBySize.put(size, pool); + } + pool.add(new WeakReference(memory)); + } + } else { + allocator.free(memory); + } } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 9224988e6ad69..10881969dbc78 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -19,6 +19,7 @@ import java.util.*; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,14 +44,22 @@ * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is * approximately 35 terabytes of memory. */ -public final class TaskMemoryManager { +public class TaskMemoryManager { private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); - /** - * The number of entries in the page table. - */ - private static final int PAGE_TABLE_SIZE = 1 << 13; + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting + static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** Maximum supported data page size */ + private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS); /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; @@ -101,11 +110,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. 
*/ public MemoryBlock allocatePage(long size) { - if (logger.isTraceEnabled()) { - logger.trace("Allocating {} byte page", size); - } - if (size >= (1L << 51)) { - throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + if (size > MAXIMUM_PAGE_SIZE) { + throw new IllegalArgumentException( + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); } final int pageNumber; @@ -120,8 +127,8 @@ public MemoryBlock allocatePage(long size) { final MemoryBlock page = executorMemoryManager.allocate(size); page.pageNumber = pageNumber; pageTable[pageNumber] = page; - if (logger.isDebugEnabled()) { - logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); } return page; } @@ -130,9 +137,6 @@ public MemoryBlock allocatePage(long size) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. */ public void freePage(MemoryBlock page) { - if (logger.isTraceEnabled()) { - logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); - } assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; executorMemoryManager.free(page); @@ -140,8 +144,8 @@ public void freePage(MemoryBlock page) { allocatedPages.clear(page.pageNumber); } pageTable[page.pageNumber] = null; - if (logger.isDebugEnabled()) { - logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } } @@ -173,14 +177,36 @@ public void free(MemoryBlock memory) { /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (inHeap) { - assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); - } else { - return offsetInPage; + if (!inHeap) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. 
+ offsetInPage -= page.getBaseOffset(); } + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); } /** @@ -189,7 +215,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { */ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { - final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final Object page = pageTable[pageNumber].getBaseObject(); assert (page != null); @@ -204,10 +230,15 @@ public Object getPage(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); if (inHeap) { - return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + return offsetInPage; } else { - return pagePlusOffsetAddress; + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + return pageTable[pageNumber].getBaseOffset() + offsetInPage; } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java new file mode 100644 index 0000000000000..eb7475e9df869 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import java.io.Serializable; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * The internal representation of interval type. 
+ */ +public final class Interval implements Serializable { + public static final long MICROS_PER_MILLI = 1000L; + public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; + public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; + public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60; + public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; + public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; + + /** + * A function to generate regex which matches interval string's unit part like "3 years". + * + * First, we can leave out some units in interval string, and we only care about the value of + * unit, so here we use non-capturing group to wrap the actual regex. + * At the beginning of the actual regex, we should match spaces before the unit part. + * Next is the number part, starts with an optional "-" to represent negative value. We use + * capturing group to wrap this part as we need the value later. + * Finally is the unit name, ends with an optional "s". + */ + private static String unitRegex(String unit) { + return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?"; + } + + private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") + + unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + + unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond")); + + private static long toLong(String s) { + if (s == null) { + return 0; + } else { + return Long.valueOf(s); + } + } + + public static Interval fromString(String s) { + if (s == null) { + return null; + } + Matcher m = p.matcher(s); + if (!m.matches() || s.equals("interval")) { + return null; + } else { + long months = toLong(m.group(1)) * 12 + toLong(m.group(2)); + long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK; + microseconds += toLong(m.group(4)) * MICROS_PER_DAY; + microseconds += toLong(m.group(5)) * MICROS_PER_HOUR; + microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE; + microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; + microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; + microseconds += toLong(m.group(9)); + return new Interval((int) months, microseconds); + } + } + + public final int months; + public final long microseconds; + + public Interval(int months, long microseconds) { + this.months = months; + this.microseconds = microseconds; + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || !(other instanceof Interval)) return false; + + Interval o = (Interval) other; + return this.months == o.months && this.microseconds == o.microseconds; + } + + @Override + public int hashCode() { + return 31 * months + (int) microseconds; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("interval"); + + if (months != 0) { + appendUnit(sb, months / 12, "year"); + appendUnit(sb, months % 12, "month"); + } + + if (microseconds != 0) { + long rest = microseconds; + appendUnit(sb, rest / MICROS_PER_WEEK, "week"); + rest %= MICROS_PER_WEEK; + appendUnit(sb, rest / MICROS_PER_DAY, "day"); + rest %= MICROS_PER_DAY; + appendUnit(sb, rest / MICROS_PER_HOUR, "hour"); + rest %= MICROS_PER_HOUR; + appendUnit(sb, rest / MICROS_PER_MINUTE, "minute"); + rest %= MICROS_PER_MINUTE; + appendUnit(sb, rest / MICROS_PER_SECOND, "second"); + rest %= MICROS_PER_SECOND; + appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); + rest %= MICROS_PER_MILLI; + appendUnit(sb, rest, "microsecond"); + } + + return sb.toString(); + } + + 
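  // [Editorial sketch, not part of the patch] fromString() and toString() above round-trip a
  // normalized interval string; for example (as IntervalSuite later in this diff also checks):
  //   Interval i = Interval.fromString("interval 2 years 10 months");
  //   // i.months == 34 (2 * 12 + 10) and i.microseconds == 0
  //   i.toString();  // "interval 2 years 10 months"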
private void appendUnit(StringBuilder sb, long value, String unit) { + if (value != 0) { + sb.append(" " + value + " " + unit + "s"); + } + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java new file mode 100644 index 0000000000000..e7f9fbb2bc682 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import javax.annotation.Nonnull; +import java.io.Serializable; +import java.io.UnsupportedEncodingException; + +import org.apache.spark.unsafe.array.ByteArrayMethods; + +import static org.apache.spark.unsafe.PlatformDependent.*; + + +/** + * A UTF-8 String for internal Spark use. + *

+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison and + * search; see http://en.wikipedia.org/wiki/UTF-8 for details. + *

    + * Note: This is not designed for general use cases, should not be used outside SQL. + */ +public final class UTF8String implements Comparable, Serializable { + + @Nonnull + private final Object base; + private final long offset; + private final int numBytes; + + private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6}; + + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ + public static UTF8String fromBytes(byte[] bytes) { + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + } else { + return null; + } + } + + /** + * Creates an UTF8String from String. + */ + public static UTF8String fromString(String str) { + if (str == null) return null; + try { + return fromBytes(str.getBytes("utf-8")); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + throwException(e); + return null; + } + } + + protected UTF8String(Object base, long offset, int size) { + this.base = base; + this.offset = offset; + this.numBytes = size; + } + + /** + * Returns the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + private static int numBytesForFirstByte(final byte b) { + final int offset = (b & 0xFF) - 192; + return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + } + + /** + * Returns the number of bytes + */ + public int numBytes() { + return numBytes; + } + + /** + * Returns the number of code points in it. + */ + public int numChars() { + int len = 0; + for (int i = 0; i < numBytes; i += numBytesForFirstByte(getByte(i))) { + len += 1; + } + return len; + } + + /** + * Returns the underline bytes, will be a copy of it if it's part of another array. + */ + public byte[] getBytes() { + // avoid copy if `base` is `byte[]` + if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] + && ((byte[]) base).length == numBytes) { + return (byte[]) base; + } else { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return bytes; + } + } + + /** + * Returns a substring of this. + * @param start the position of first code point + * @param until the position after last code point, exclusive. + */ + public UTF8String substring(final int start, final int until) { + if (until <= start || start >= numBytes) { + return fromBytes(new byte[0]); + } + + int i = 0; + int c = 0; + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + int j = i; + while (i < numBytes && c < until) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + byte[] bytes = new byte[i - j]; + copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + return fromBytes(bytes); + } + + /** + * Returns whether this contains `substring` or not. + */ + public boolean contains(final UTF8String substring) { + if (substring.numBytes == 0) { + return true; + } + + byte first = substring.getByte(0); + for (int i = 0; i <= numBytes - substring.numBytes; i++) { + if (getByte(i) == first && matchAt(substring, i)) { + return true; + } + } + return false; + } + + /** + * Returns the byte at position `i`. 
+ */ + private byte getByte(int i) { + return UNSAFE.getByte(base, offset + i); + } + + private boolean matchAt(final UTF8String s, int pos) { + if (s.numBytes + pos > numBytes || pos < 0) { + return false; + } + return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + } + + public boolean startsWith(final UTF8String prefix) { + return matchAt(prefix, 0); + } + + public boolean endsWith(final UTF8String suffix) { + return matchAt(suffix, numBytes - suffix.numBytes); + } + + /** + * Returns the upper case of this string + */ + public UTF8String toUpperCase() { + return fromString(toString().toUpperCase()); + } + + /** + * Returns the lower case of this string + */ + public UTF8String toLowerCase() { + return fromString(toString().toLowerCase()); + } + + /** + * Copy the bytes from the current UTF8String, and make a new UTF8String. + * @param start the start position of the current UTF8String in bytes. + * @param end the end position of the current UTF8String in bytes. + * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. + */ + private UTF8String copyUTF8String(int start, int end) { + int len = end - start + 1; + byte[] newBytes = new byte[len]; + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + return UTF8String.fromBytes(newBytes); + } + + public UTF8String trim() { + int s = 0; + int e = this.numBytes - 1; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (s > e) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, e); + } + } + + public UTF8String trimLeft() { + int s = 0; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + if (s == this.numBytes) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + + public UTF8String trimRight() { + int e = numBytes - 1; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (e < 0) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(0, e); + } + } + + public UTF8String reverse() { + byte[] bytes = getBytes(); + byte[] result = new byte[bytes.length]; + + int i = 0; // position in byte + while (i < numBytes) { + int len = numBytesForFirstByte(getByte(i)); + System.arraycopy(bytes, i, result, result.length - i - len, len); + + i += len; + } + + return UTF8String.fromBytes(result); + } + + public UTF8String repeat(int times) { + if (times <=0) { + return fromBytes(new byte[0]); + } + + byte[] newBytes = new byte[numBytes * times]; + System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + + int copied = 1; + while (copied < times) { + int toCopy = Math.min(copied, times - copied); + System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy); + copied += toCopy; + } + + return UTF8String.fromBytes(newBytes); + } + + /** + * Returns the position of the first occurrence of substr in + * current string from the specified position (0-based index). + * + * @param v the string to be searched + * @param start the start position of the current string for searching + * @return the position of the first occurrence of substr, if not found, -1 returned. 
+ */ + public int indexOf(UTF8String v, int start) { + if (v.numBytes() == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + v.numBytes > numBytes) { + return -1; + } + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while(i < numBytes); + + return -1; + } + + /** + * Returns str, right-padded with pad to a length of len + * For example: + * ('hi', 5, '??') => 'hi???' + * ('hi', 1, '??') => 'h' + */ + public UTF8String rpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + int offset = this.numBytes; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + + return UTF8String.fromBytes(data); + } + } + + /** + * Returns str, left-padded with pad to a length of len. + * For example: + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' + */ + public UTF8String lpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + + int offset = 0; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + offset += remain.numBytes; + System.arraycopy(getBytes(), 0, data, offset, numBytes()); + + return UTF8String.fromBytes(data); + } + } + + @Override + public String toString() { + try { + return new String(getBytes(), "utf-8"); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + throwException(e); + return "unknown"; // we will never reach here. + } + } + + @Override + public UTF8String clone() { + return fromBytes(getBytes()); + } + + @Override + public int compareTo(final UTF8String other) { + int len = Math.min(numBytes, other.numBytes); + // TODO: compare 8 bytes as unsigned long + for (int i = 0; i < len; i ++) { + // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. 
+ int res = (getByte(i) & 0xFF) - (other.getByte(i) & 0xFF); + if (res != 0) { + return res; + } + } + return numBytes - other.numBytes; + } + + public int compare(final UTF8String other) { + return compareTo(other); + } + + @Override + public boolean equals(final Object other) { + if (other instanceof UTF8String) { + UTF8String o = (UTF8String) other; + if (numBytes != o.numBytes){ + return false; + } + return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + } else { + return false; + } + } + + /** + * Levenshtein distance is a metric for measuring the distance of two strings. The distance is + * defined by the minimum number of single-character edits (i.e. insertions, deletions or + * substitutions) that are required to change one of the strings into the other. + */ + public int levenshteinDistance(UTF8String other) { + // Implementation adopted from org.apache.common.lang3.StringUtils.getLevenshteinDistance + + int n = numChars(); + int m = other.numChars(); + + if (n == 0) { + return m; + } else if (m == 0) { + return n; + } + + UTF8String s, t; + + if (n <= m) { + s = this; + t = other; + } else { + s = other; + t = this; + int swap; + swap = n; + n = m; + m = swap; + } + + int p[] = new int[n + 1]; + int d[] = new int[n + 1]; + int swap[]; + + int i, i_bytes, j, j_bytes, num_bytes_j, cost; + + for (i = 0; i <= n; i++) { + p[i] = i; + } + + for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) { + num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes)); + d[0] = j + 1; + + for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) { + if (s.getByte(i_bytes) != t.getByte(j_bytes) || + num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { + cost = 1; + } else { + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 
0 : 1; + } + d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); + } + + swap = p; + p = d; + d = swap; + } + + return p[n]; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < numBytes; i ++) { + result = 31 * result + getByte(i); + } + return result; + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java index 18393db9f382f..a93fc0ee297c4 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.bitset; import junit.framework.Assert; -import org.apache.spark.unsafe.bitset.BitSet; import org.junit.Test; import org.apache.spark.unsafe.memory.MemoryBlock; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 7a5c0622d1ffb..dae47e4bab0cb 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -25,24 +25,40 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.mockito.AdditionalMatchers.geq; +import static org.mockito.Mockito.*; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.*; import org.apache.spark.unsafe.PlatformDependent; import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; + public abstract class AbstractBytesToBytesMapSuite { private final Random rand = new Random(42); private TaskMemoryManager memoryManager; + private TaskMemoryManager sizeLimitedMemoryManager; @Before public void setup() { memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + // Mocked memory manager for tests that check the maximum array size, since actually allocating + // such large arrays will cause us to run out of memory in our tests. 
+ sizeLimitedMemoryManager = spy(memoryManager); + when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer() { + @Override + public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { + if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { + throw new OutOfMemoryError("Requested array size exceeds VM limit"); + } + return memoryManager.allocate(1L << 20); + } + }); } @After @@ -83,7 +99,7 @@ private static boolean arrayEquals( byte[] expected, MemoryLocation actualAddr, long actualLengthBytes) { - return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals( + return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, BYTE_ARRAY_OFFSET, actualAddr.getBaseObject(), @@ -101,6 +117,7 @@ public void emptyMap() { final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); } @@ -159,7 +176,7 @@ public void setAndRetrieveAKey() { @Test public void iteratorTest() throws Exception { - final int size = 128; + final int size = 4096; BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); try { for (long i = 0; i < size; i++) { @@ -167,14 +184,26 @@ public void iteratorTest() throws Exception { final BytesToBytesMap.Location loc = map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); - loc.putNewKey( - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8, - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8 - ); + // Ensure that we store some zero-length keys + if (i % 5 == 0) { + loc.putNewKey( + null, + PlatformDependent.LONG_ARRAY_OFFSET, + 0, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } else { + loc.putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } } final java.util.BitSet valuesSeen = new java.util.BitSet(size); final Iterator iter = map.iterator(); @@ -183,11 +212,16 @@ public void iteratorTest() throws Exception { Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); final long value = PlatformDependent.UNSAFE.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); - Assert.assertEquals(key, value); + final long keyLength = loc.getKeyLength(); + if (keyLength == 0) { + Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); + } else { + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + Assert.assertEquals(value, key); + } valuesSeen.set((int) value); } Assert.assertEquals(size, valuesSeen.cardinality()); @@ -196,6 +230,74 @@ public void iteratorTest() throws Exception { } } + @Test + public void iteratingOverDataPagesWithWastedSpace() throws Exception { + final int NUM_ENTRIES = 1000 * 1000; + final int KEY_LENGTH = 16; + final int VALUE_LENGTH = 40; + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES); + // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. 
Our 64-megabyte + // pages won't be evenly-divisible by records of this size, which will cause us to waste some + // space at the end of the page. This is necessary in order for us to take the end-of-record + // handling branch in iterator(). + try { + for (int i = 0; i < NUM_ENTRIES; i++) { + final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes + final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes + final BytesToBytesMap.Location loc = map.lookup( + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH, + value, + LONG_ARRAY_OFFSET, + VALUE_LENGTH + ); + } + Assert.assertEquals(2, map.getNumDataPages()); + + final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); + final Iterator iter = map.iterator(); + final long key[] = new long[KEY_LENGTH / 8]; + final long value[] = new long[VALUE_LENGTH / 8]; + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); + Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); + PlatformDependent.copyMemory( + loc.getKeyAddress().getBaseObject(), + loc.getKeyAddress().getBaseOffset(), + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + PlatformDependent.copyMemory( + loc.getValueAddress().getBaseObject(), + loc.getValueAddress().getBaseOffset(), + value, + LONG_ARRAY_OFFSET, + VALUE_LENGTH + ); + for (long j : key) { + Assert.assertEquals(key[0], j); + } + for (long j : value) { + Assert.assertEquals(key[0], j); + } + valuesSeen.set((int) key[0]); + } + Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + @Test public void randomizedStressTest() { final int size = 65536; @@ -247,4 +349,35 @@ public void randomizedStressTest() { map.free(); } } + + @Test + public void initialCapacityBoundsChecking() { + try { + new BytesToBytesMap(sizeLimitedMemoryManager, 0); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + try { + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + // Can allocate _at_ the max capacity + BytesToBytesMap map = + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY); + map.free(); + } + + @Test + public void resizingLargeMap() { + // As long as a map's capacity is below the max, we should be able to resize up to the max + BytesToBytesMap map = + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64); + map.growAndRehash(); + map.free(); + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java index 932882f1ca248..06fb081183659 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -38,4 +38,27 @@ public void leakedPageMemoryIsDetected() { Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } + @Test + public void encodePageNumberAndOffsetOffHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock dataPage = 
manager.allocatePage(256); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); + Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void encodePageNumberAndOffsetOnHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = manager.allocatePage(256); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); + } + } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java new file mode 100644 index 0000000000000..44a949a371f2b --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -0,0 +1,105 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.unsafe.types; + +import org.junit.Test; + +import static junit.framework.Assert.*; +import static org.apache.spark.unsafe.types.Interval.*; + +public class IntervalSuite { + + @Test + public void equalsTest() { + Interval i1 = new Interval(3, 123); + Interval i2 = new Interval(3, 321); + Interval i3 = new Interval(1, 123); + Interval i4 = new Interval(3, 123); + + assertNotSame(i1, i2); + assertNotSame(i1, i3); + assertNotSame(i2, i3); + assertEquals(i1, i4); + } + + @Test + public void toStringTest() { + Interval i; + + i = new Interval(34, 0); + assertEquals(i.toString(), "interval 2 years 10 months"); + + i = new Interval(-34, 0); + assertEquals(i.toString(), "interval -2 years -10 months"); + + i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); + + i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); + + i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); + } + + @Test + public void fromStringTest() { + testSingleUnit("year", 3, 36, 0); + testSingleUnit("month", 3, 3, 0); + testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK); + testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY); + testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR); + testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE); + testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND); + testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI); + testSingleUnit("microsecond", 3, 0, 3); + + String input; + + input = "interval -5 years 23 month"; + Interval result = new Interval(-5 * 12 + 23, 0); + assertEquals(Interval.fromString(input), result); + + // Error cases + input = "interval 3month 1 hour"; + assertEquals(Interval.fromString(input), null); + + input = "interval 3 moth 1 hour"; + assertEquals(Interval.fromString(input), null); + + input = "interval"; + assertEquals(Interval.fromString(input), null); + + input = "int"; + assertEquals(Interval.fromString(input), null); + + input = ""; + assertEquals(Interval.fromString(input), null); + + input = null; + assertEquals(Interval.fromString(input), null); + } + + private void testSingleUnit(String unit, int number, int months, long microseconds) { + String input1 = "interval " + number + " " + unit; + String input2 = "interval " + number + " " + unit + "s"; + Interval result = new Interval(months, microseconds); + assertEquals(Interval.fromString(input1), result); + assertEquals(Interval.fromString(input2), result); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java new file mode 100644 index 0000000000000..694bdc29f39d1 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -0,0 +1,237 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.unsafe.types; + +import java.io.UnsupportedEncodingException; + +import org.junit.Test; + +import static junit.framework.Assert.*; + +import static org.apache.spark.unsafe.types.UTF8String.*; + +public class UTF8StringSuite { + + private void checkBasic(String str, int len) throws UnsupportedEncodingException { + UTF8String s1 = fromString(str); + UTF8String s2 = fromBytes(str.getBytes("utf8")); + assertEquals(s1.numChars(), len); + assertEquals(s2.numChars(), len); + + assertEquals(s1.toString(), str); + assertEquals(s2.toString(), str); + assertEquals(s1, s2); + + assertEquals(s1.hashCode(), s2.hashCode()); + + assertEquals(s1.compareTo(s2), 0); + + assertEquals(s1.contains(s2), true); + assertEquals(s2.contains(s1), true); + assertEquals(s1.startsWith(s1), true); + assertEquals(s1.endsWith(s1), true); + } + + @Test + public void basicTest() throws UnsupportedEncodingException { + checkBasic("", 0); + checkBasic("hello", 5); + checkBasic("大 千 世 界", 7); + } + + @Test + public void compareTo() { + assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); + assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); + assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); + assertTrue(fromString("aBcabcabc").compareTo(fromString("Abcabcabc")) > 0); + assertTrue(fromString("Abcabcabc").compareTo(fromString("abcabcabC")) < 0); + assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabC")) > 0); + + assertTrue(fromString("abc").compareTo(fromString("世界")) < 0); + assertTrue(fromString("你好").compareTo(fromString("世界")) > 0); + assertTrue(fromString("你好123").compareTo(fromString("你好122")) > 0); + } + + protected void testUpperandLower(String upper, String lower) { + UTF8String us = fromString(upper); + UTF8String ls = fromString(lower); + assertEquals(ls, us.toLowerCase()); + assertEquals(us, ls.toUpperCase()); + assertEquals(us, us.toUpperCase()); + assertEquals(ls, ls.toLowerCase()); + } + + @Test + public void upperAndLower() { + testUpperandLower("", ""); + testUpperandLower("0123456", "0123456"); + testUpperandLower("ABCXYZ", "abcxyz"); + testUpperandLower("ЀЁЂѺΏỀ", "ѐёђѻώề"); + testUpperandLower("大千世界 数据砖头", "大千世界 数据砖头"); + } + + @Test + public void contains() { + assertTrue(fromString("").contains(fromString(""))); + assertTrue(fromString("hello").contains(fromString("ello"))); + assertFalse(fromString("hello").contains(fromString("vello"))); + assertFalse(fromString("hello").contains(fromString("hellooo"))); + assertTrue(fromString("大千世界").contains(fromString("千世界"))); + assertFalse(fromString("大千世界").contains(fromString("世千"))); + assertFalse(fromString("大千世界").contains(fromString("大千世界好"))); + } + + @Test + public void startsWith() { + assertTrue(fromString("").startsWith(fromString(""))); + assertTrue(fromString("hello").startsWith(fromString("hell"))); + assertFalse(fromString("hello").startsWith(fromString("ell"))); + assertFalse(fromString("hello").startsWith(fromString("hellooo"))); + assertTrue(fromString("数据砖头").startsWith(fromString("数据"))); + 
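
What these assertions really pin down is that `UTF8String` operates directly on UTF-8 encoded bytes: character counts come from decoding lead bytes (which is why "大 千 世 界" has 7 characters), and prefix/suffix/containment checks can be implemented as plain byte comparisons because a valid UTF-8 boundary can never fall inside another character. The helper below is my own sketch of the byte-level character count, not the class's internals, and it assumes well-formed UTF-8 input.

```scala
import java.nio.charset.StandardCharsets

object Utf8Sketch {
  /** Bytes used by the code point starting with this lead byte (1-4 in valid UTF-8). */
  def bytesForLeadByte(b: Byte): Int = {
    val u = b & 0xFF
    if (u < 0x80) 1            // 0xxxxxxx : ASCII
    else if (u < 0xE0) 2       // 110xxxxx
    else if (u < 0xF0) 3       // 1110xxxx
    else 4                     // 11110xxx
  }

  /** Number of code points in a valid UTF-8 byte array. */
  def numChars(bytes: Array[Byte]): Int = {
    var i = 0
    var count = 0
    while (i < bytes.length) {
      i += bytesForLeadByte(bytes(i))
      count += 1
    }
    count
  }

  def main(args: Array[String]): Unit = {
    assert(numChars("hello".getBytes(StandardCharsets.UTF_8)) == 5)
    assert(numChars("大 千 世 界".getBytes(StandardCharsets.UTF_8)) == 7)  // 4 CJK chars + 3 spaces

    // A byte-wise prefix comparison is enough for a startsWith-style check,
    // because continuation bytes can never be mistaken for lead bytes.
    val s = "数据砖头".getBytes(StandardCharsets.UTF_8)
    val p = "数据".getBytes(StandardCharsets.UTF_8)
    assert(s.take(p.length).sameElements(p))
  }
}
```
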
assertFalse(fromString("大千世界").startsWith(fromString("千"))); + assertFalse(fromString("大千世界").startsWith(fromString("大千世界好"))); + } + + @Test + public void endsWith() { + assertTrue(fromString("").endsWith(fromString(""))); + assertTrue(fromString("hello").endsWith(fromString("ello"))); + assertFalse(fromString("hello").endsWith(fromString("ellov"))); + assertFalse(fromString("hello").endsWith(fromString("hhhello"))); + assertTrue(fromString("大千世界").endsWith(fromString("世界"))); + assertFalse(fromString("大千世界").endsWith(fromString("世"))); + assertFalse(fromString("数据砖头").endsWith(fromString("我的数据砖头"))); + } + + @Test + public void substring() { + assertEquals(fromString(""), fromString("hello").substring(0, 0)); + assertEquals(fromString("el"), fromString("hello").substring(1, 3)); + assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); + assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + } + + @Test + public void trims() { + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + + assertEquals(fromString(""), fromString(" ").trim()); + assertEquals(fromString(""), fromString(" ").trimLeft()); + assertEquals(fromString(""), fromString(" ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + } + + @Test + public void indexOf() { + assertEquals(0, fromString("").indexOf(fromString(""), 0)); + assertEquals(-1, fromString("").indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(fromString(""), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0)); + assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4)); + assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0)); + assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); + } + + @Test + public void reverse() { + assertEquals(fromString("olleh"), fromString("hello").reverse()); + assertEquals(fromString(""), fromString("").reverse()); + assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); + assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + } + + @Test + public void repeat() { + assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); + assertEquals(fromString("数d"), fromString("数d").repeat(1)); + assertEquals(fromString(""), fromString("数d").repeat(-1)); + } + + @Test + public void pad() { + assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????"))); + assertEquals(fromString("?hello"), 
fromString("hello").lpad(6, fromString("????"))); + assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); + assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????"))); + + assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); + assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); + assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); + assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????"))); + + + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); + assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????"))); + assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); + assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); + assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者"))); + assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); + assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); + assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); + assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); + } + + @Test + public void levenshteinDistance() { + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")), 0); + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")), 1); + assertEquals( + UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")), 7); + assertEquals( + UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")), 1); + assertEquals( + UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")),3); + assertEquals( + UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")), 8); + assertEquals( + UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")),1); + assertEquals( + UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4); + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 7c8c3613e7a05..2aeed98285aa8 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -30,6 +30,7 @@ Spark Project YARN yarn + 1.9 @@ -38,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.hadoop hadoop-yarn-api @@ -85,7 +93,12 @@ jetty-servlet - + + + org.apache.hadoop hadoop-yarn-server-tests @@ -94,62 
+107,47 @@ org.mockito - mockito-all + mockito-core test + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + ${jersey.version} + test + + + com.sun.jersey + jersey-json + ${jersey.version} + test + + + stax + stax-api + + + + + com.sun.jersey + jersey-server + ${jersey.version} + test + - - - - - hadoop-2.2 - - 1.9 - - - - org.mortbay.jetty - jetty - 6.1.26 - - - org.mortbay.jetty - servlet-api - - - test - - - com.sun.jersey - jersey-core - ${jersey.version} - test - - - com.sun.jersey - jersey-json - ${jersey.version} - test - - - stax - stax-api - - - - - com.sun.jersey - jersey-server - ${jersey.version} - test - - - - - + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index aaae6f9734a85..56e4741b93873 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -60,8 +60,13 @@ private[yarn] class AMDelegationTokenRenewer( private val hadoopUtil = YarnSparkHadoopUtil.get - private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) - private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val daysToKeepFiles = + sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = + sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val freshHadoopConf = + hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -120,8 +125,8 @@ private[yarn] class AMDelegationTokenRenewer( private def cleanupOldFiles(): Unit = { import scala.concurrent.duration._ try { - val remoteFs = FileSystem.get(hadoopConf) - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val remoteFs = FileSystem.get(freshHadoopConf) + val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( remoteFs, credentialsPath.getParent, @@ -160,19 +165,19 @@ private[yarn] class AMDelegationTokenRenewer( val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val dst = credentialsPath.getParent keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials override def run(): Void = { val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst - hadoopUtil.obtainTokensForNamenodes(nns, hadoopConf, tempCreds) + hadoopUtil.obtainTokensForNamenodes(nns, freshHadoopConf, tempCreds) null } }) // Add the temp credentials back to the original ones. 
UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file // and update the lastCredentialsFileSuffix. @@ -186,13 +191,12 @@ private[yarn] class AMDelegationTokenRenewer( } val nextSuffix = lastCredentialsFileSuffix + 1 val tokenPathStr = - sparkConf.get("spark.yarn.credentials.file") + - SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix val tokenPath = new Path(tokenPathStr) val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, hadoopConf) + credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") remoteFs.rename(tempTokenPath, tokenPath) logInfo("Delegation token file rename complete.") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 29752969e6152..83dafa4a125d2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -32,9 +32,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException -import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer -import org.apache.spark.scheduler.cluster.YarnSchedulerBackend +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util._ @@ -46,6 +46,14 @@ private[spark] class ApplicationMaster( client: YarnRMClient) extends Logging { + // Load the properties file with the Spark configuration and set entries as system properties, + // so that user code run inside the AM also has access to them. + if (args.propertiesFile != null) { + Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) => + sys.props(k) = v + } + } + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -67,6 +75,7 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + private val allocatorLock = new Object() // Fields used in client mode. 
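
The switch from `hadoopConf` to `freshHadoopConf` in the token renewer exists so that the renewer talks to a `FileSystem` instance that is not served from Hadoop's process-wide cache; a cached instance would keep using whatever credentials it was created with. The sketch below shows a plausible shape for such a cache-bypassing configuration. It assumes the standard `fs.<scheme>.impl.disable.cache` switch and uses a hypothetical local path; `getConfBypassingFSCache` may differ in detail.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object FreshFsConfSketch {
  /** Copy the given conf and disable the FileSystem cache for one URI scheme. */
  def confBypassingFsCache(base: Configuration, scheme: String): Configuration = {
    val fresh = new Configuration(base)                       // do not mutate the shared conf
    fresh.setBoolean(s"fs.$scheme.impl.disable.cache", true)  // e.g. fs.hdfs.impl.disable.cache
    fresh
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical location; in the AM this would be spark.yarn.credentials.file on HDFS.
    val credentialsFile = "file:///tmp/spark-credentials"
    val scheme = new Path(credentialsFile).toUri.getScheme
    val fresh = confBypassingFsCache(new Configuration(), scheme)
    // With the cache disabled, this constructs a new FileSystem instead of
    // returning the shared cached instance for the scheme.
    val fs = FileSystem.get(new java.net.URI(credentialsFile), fresh)
    println(fs.getClass.getName)
  }
}
```
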
private var rpcEnv: RpcEnv = null @@ -220,7 +229,7 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -231,8 +240,14 @@ private[spark] class ApplicationMaster( .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } .getOrElse("") - allocator = client.register(yarnConf, - if (sc != null) sc.getConf else sparkConf, + val _sparkConf = if (sc != null) sc.getConf else sparkConf + val driverUrl = _rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + allocator = client.register(driverUrl, + yarnConf, + _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, @@ -279,7 +294,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -289,7 +304,7 @@ private[spark] class ApplicationMaster( rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -300,11 +315,14 @@ private[spark] class ApplicationMaster( val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "5s") + val heartbeatInterval = math.max(0, math.min(expiryInterval / 2, + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) - // must be <= expiryInterval / 2. 
- val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) + // we want to check more frequently for pending containers + val initialAllocationInterval = math.min(heartbeatInterval, + sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) + + var nextAllocationInterval = initialAllocationInterval // The number of failures in a row until Reporter thread give up val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) @@ -330,15 +348,29 @@ private[spark] class ApplicationMaster( if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + - s"${failureCount} time(s) from Reporter thread.") - + s"$failureCount time(s) from Reporter thread.") } else { - logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) + logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e) } } } try { - Thread.sleep(interval) + val numPendingAllocate = allocator.getNumPendingAllocate + val sleepInterval = + if (numPendingAllocate > 0) { + val currentAllocationInterval = + math.min(heartbeatInterval, nextAllocationInterval) + nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow + currentAllocationInterval + } else { + nextAllocationInterval = initialAllocationInterval + heartbeatInterval + } + logDebug(s"Number of pending allocations is $numPendingAllocate. " + + s"Sleeping for $sleepInterval.") + allocatorLock.synchronized { + allocatorLock.wait(sleepInterval) + } } catch { case e: InterruptedException => } @@ -349,7 +381,8 @@ private[spark] class ApplicationMaster( t.setDaemon(true) t.setName("Reporter") t.start() - logInfo("Started progress reporter thread - sleep time : " + interval) + logInfo(s"Started progress reporter thread with (heartbeat : $heartbeatInterval, " + + s"initial allocation : $initialAllocationInterval) intervals") t } @@ -465,9 +498,11 @@ private[spark] class ApplicationMaster( new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } + var userArgs = args.userArgs if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - System.setProperty("spark.submit.pyFiles", - PythonRunner.formatPaths(args.pyFiles).mkString(",")) + // When running pyspark, the app is run using PythonRunner. The second argument is the list + // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty. 
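
The reporter-thread change replaces a single fixed sleep with two knobs: a heartbeat interval (`spark.yarn.scheduler.heartbeat.interval-ms`, capped at half the RM expiry interval) and a much shorter initial allocation interval (`spark.yarn.scheduler.initial-allocation.interval`) that doubles while container requests are pending and resets once they are satisfied, with `allocatorLock.wait`/`notifyAll` used to cut a sleep short when the driver asks for more executors. The condensed sketch below reproduces only that policy; the constants, field names, and single-threaded driver are illustrative.

```scala
object AllocationBackoffSketch {
  // Illustrative values; the real ones come from the YARN expiry interval and Spark conf.
  val heartbeatIntervalMs = 3000L
  val initialAllocationIntervalMs = 200L

  private val allocatorLock = new Object
  @volatile private var pendingAllocations = 0
  private var nextAllocationIntervalMs = initialAllocationIntervalMs

  /** One iteration of the reporter loop: pick a sleep length and wait on the lock. */
  def sleepOnce(): Long = {
    val sleepMs =
      if (pendingAllocations > 0) {
        // Poll quickly while containers are outstanding, backing off exponentially
        // (capped by the heartbeat) so the RM is not hammered indefinitely.
        val current = math.min(heartbeatIntervalMs, nextAllocationIntervalMs)
        nextAllocationIntervalMs = current * 2
        current
      } else {
        nextAllocationIntervalMs = initialAllocationIntervalMs
        heartbeatIntervalMs
      }
    allocatorLock.synchronized { allocatorLock.wait(sleepMs) }
    sleepMs
  }

  /** Called when the driver requests executors: wake the reporter early. */
  def onExecutorRequest(requested: Int): Unit = {
    pendingAllocations = requested
    allocatorLock.synchronized { allocatorLock.notifyAll() }
  }

  def main(args: Array[String]): Unit = {
    pendingAllocations = 5
    println(Seq.fill(4)(sleepOnce()))   // e.g. List(200, 400, 800, 1600)
    pendingAllocations = 0
    println(sleepOnce())                // 3000, and the backoff resets
  }
}
```
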
+ userArgs = Seq(args.primaryPyFile, "") ++ userArgs } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { // TODO(davies): add R dependencies here @@ -478,9 +513,7 @@ private[spark] class ApplicationMaster( val userThread = new Thread { override def run() { try { - val mainArgs = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) + mainMethod.invoke(null, userArgs.toArray) finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) logDebug("Done running users class") } catch { @@ -524,8 +557,15 @@ private[spark] class ApplicationMaster( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { - case Some(a) => a.requestTotalExecutors(requestedTotal) - case None => logWarning("Container allocator is not ready to request executors yet.") + case Some(a) => + allocatorLock.synchronized { + if (a.requestTotalExecutors(requestedTotal)) { + allocatorLock.notifyAll() + } + } + + case None => + logWarning("Container allocator is not ready to request executors yet.") } context.reply(true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index ae6dc1094d724..37f793763367e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -26,11 +26,11 @@ class ApplicationMasterArguments(val args: Array[String]) { var userClass: String = null var primaryPyFile: String = null var primaryRFile: String = null - var pyFiles: String = null - var userArgs: Seq[String] = Seq[String]() + var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 var numExecutors = DEFAULT_NUMBER_EXECUTORS + var propertiesFile: String = null parseArgs(args.toList) @@ -59,10 +59,6 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryRFile = value args = tail - case ("--py-files") :: value :: tail => - pyFiles = value - args = tail - case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -79,13 +75,19 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail + case ("--properties-file") :: value :: tail => + propertiesFile = value + args = tail + case _ => printUsageAndExit(1, args) } } if (primaryPyFile != null && primaryRFile != null) { + // scalastyle:off println System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + // scalastyle:on println System.exit(-1) } @@ -93,6 +95,7 @@ class ApplicationMasterArguments(val args: Array[String]) { } def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + // scalastyle:off println if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) } @@ -111,6 +114,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 
1000M, 2G) (Default: 1G) """.stripMargin) + // scalastyle:on println System.exit(exitCode) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d21a7393478ce..f86b6d1e5d7bc 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,18 +17,21 @@ package org.apache.spark.deploy.yarn -import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException, + OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction -import java.util.UUID +import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import scala.util.control.NonFatal +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files @@ -91,30 +94,59 @@ private[spark] class Client( * available in the alpha API. */ def submitApplication(): ApplicationId = { - // Setup the credentials before doing anything else, so we have don't have issues at any point. - setupCredentials() - yarnClient.init(yarnConf) - yarnClient.start() - - logInfo("Requesting a new application from cluster with %d NodeManagers" - .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) - - // Get a new application from our RM - val newApp = yarnClient.createApplication() - val newAppResponse = newApp.getNewApplicationResponse() - val appId = newAppResponse.getApplicationId() - - // Verify whether the cluster has enough resources for our AM - verifyClusterResources(newAppResponse) - - // Set up the appropriate contexts to launch our AM - val containerContext = createContainerLaunchContext(newAppResponse) - val appContext = createApplicationSubmissionContext(newApp, containerContext) - - // Finally, submit and monitor the application - logInfo(s"Submitting application ${appId.getId} to ResourceManager") - yarnClient.submitApplication(appContext) - appId + var appId: ApplicationId = null + try { + // Setup the credentials before doing anything else, + // so we don't have issues at any point.
+ setupCredentials() + yarnClient.init(yarnConf) + yarnClient.start() + + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) + + // Get a new application from our RM + val newApp = yarnClient.createApplication() + val newAppResponse = newApp.getNewApplicationResponse() + appId = newAppResponse.getApplicationId() + + // Verify whether the cluster has enough resources for our AM + verifyClusterResources(newAppResponse) + + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(newApp, containerContext) + + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + yarnClient.submitApplication(appContext) + appId + } catch { + case e: Throwable => + if (appId != null) { + cleanupStagingDir(appId) + } + throw e + } + } + + /** + * Cleanup application staging directory. + */ + private def cleanupStagingDir(appId: ApplicationId): Unit = { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } } /** @@ -218,7 +250,9 @@ private[spark] class Client( * This is used for setting up a container launch context for our ApplicationMaster. * Exposed for testing. */ - def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources( + appStagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, LocalResource] = { logInfo("Preparing resources for our AM container") // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. @@ -248,20 +282,6 @@ private[spark] class Client( "for alternatives.") } - // If we passed in a keytab, make sure we copy the keytab to the staging directory on - // HDFS, and setup the relevant environment vars, so the AM can login again. - if (loginFromKeytab) { - logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + - " via the YARN Secure Distributed Cache.") - val localUri = new URI(args.keytab) - val localPath = getQualifiedLocalPath(localUri, hadoopConf) - val destinationPath = copyFileToRemote(dst, localPath, replication) - val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf) - distCacheMgr.addResource( - destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE, - sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true) - } - def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -273,6 +293,58 @@ private[spark] class Client( } } + /** + * Distribute a file to the cluster. + * + * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied + * to HDFS (if not already there) and added to the application's distributed cache. + * + * @param path URI of the file to distribute. + * @param resType Type of resource being distributed. 
+ * @param destName Name of the file in the distributed cache. + * @param targetDir Subdirectory where to place the file. + * @param appMasterOnly Whether to distribute only to the AM. + * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the + * localized path for non-local paths, or the input `path` for local paths. + * The localized path will be null if the URI has already been added to the cache. + */ + def distribute( + path: String, + resType: LocalResourceType = LocalResourceType.FILE, + destName: Option[String] = None, + targetDir: Option[String] = None, + appMasterOnly: Boolean = false): (Boolean, String) = { + val localURI = new URI(path.trim()) + if (localURI.getScheme != LOCAL_SCHEME) { + if (addDistributedUri(localURI)) { + val localPath = getQualifiedLocalPath(localURI, hadoopConf) + val linkname = targetDir.map(_ + "/").getOrElse("") + + destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) + distCacheMgr.addResource( + destFs, hadoopConf, destPath, localResources, resType, linkname, statCache, + appMasterOnly = appMasterOnly) + (false, linkname) + } else { + (false, null) + } + } else { + (true, path.trim()) + } + } + + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and setup the relevant environment vars, so the AM can login again. + if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val (_, localizedPath) = distribute(args.keytab, + destName = Some(sparkConf.get("spark.yarn.keytab")), + appMasterOnly = true) + require(localizedPath != null, "Keytab file already distributed.") + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. 
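
The new `distribute()` helper centralizes a rule that used to be repeated per resource type: a path with the `local:` scheme is assumed to already exist on every node and is used in place, while anything else is copied to the staging directory once and exposed through the YARN distributed cache under a link name (the URI fragment if present, otherwise the file name). The snippet below is a tiny standalone illustration of just that naming rule, not `Client.scala` itself; the paths in `main` are invented.

```scala
import java.net.URI

object LocalSchemeSketch {
  val LocalScheme = "local"

  /** Returns (isLocal, linkName): the cache link name is the fragment, else the file name. */
  def resolve(path: String): (Boolean, String) = {
    val uri = new URI(path.trim)
    if (uri.getScheme == LocalScheme) {
      (true, uri.getPath)
    } else {
      val name = Option(uri.getFragment).getOrElse(new java.io.File(uri.getPath).getName)
      (false, name)
    }
  }

  def main(args: Array[String]): Unit = {
    println(resolve("local:/opt/libs/app.jar"))              // (true,/opt/libs/app.jar)
    println(resolve("hdfs:///user/me/dep.jar#renamed.jar"))  // (false,renamed.jar)
    println(resolve("/tmp/extra.py"))                        // (false,extra.py)
  }
}
```
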
@@ -285,33 +357,18 @@ private[spark] class Client( (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), ("log4j.properties", oldLog4jConf.orNull, null) - ).foreach { case (destName, _localPath, confKey) => - val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (!localPath.isEmpty()) { - val localURI = new URI(localPath) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) - } - } else if (confKey != null) { + ).foreach { case (destName, path, confKey) => + if (path != null && !path.trim().isEmpty()) { + val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) + if (isLocal && confKey != null) { + require(localizedPath != null, s"Path $path already distributed.") // If the resource is intended for local use only, handle this downstream // by setting the appropriate property - sparkConf.set(confKey, localPath) + sparkConf.set(confKey, localizedPath) } } } - createConfArchive().foreach { file => - require(addDistributedUri(file.toURI())) - val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) - distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, - LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) - } - /** * Do the same for any additional resources passed in through ClientArguments. * Each resource category is represented by a 3-tuple of: @@ -327,21 +384,10 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { flist.split(',').foreach { file => - val localURI = new URI(file.trim()) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname - } - } - } else if (addToClasspath) { - // Resource is intended for local use only and should be added to the class path - cachedSecondaryJarLinks += file.trim() + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } @@ -350,11 +396,31 @@ private[spark] class Client( sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) } + if (isClusterMode && args.primaryPyFile != null) { + distribute(args.primaryPyFile, appMasterOnly = true) + } + + pySparkArchives.foreach { f => distribute(f) } + + // The python files list needs to be treated especially. All files that are not an + // archive need to be placed in a subdirectory that will be added to PYTHONPATH. + args.pyFiles.foreach { f => + val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None + distribute(f, targetDir = targetDir) + } + + // Distribute an archive with Hadoop and Spark configuration for the AM. 
+ val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + resType = LocalResourceType.ARCHIVE, + destName = Some(LOCALIZED_CONF_DIR), + appMasterOnly = true) + require(confLocalizedPath != null) + localResources } /** - * Create an archive with the Hadoop config files for distribution. + * Create an archive with the config files for distribution. * * These are only used by the AM, since executors will use the configuration object broadcast by * the driver. The files are zipped and added to the job as an archive, so that YARN will explode @@ -366,8 +432,11 @@ private[spark] class Client( * * Currently this makes a shallow copy of the conf directory. If there are cases where a * Hadoop config directory contains subdirectories, this code will have to be fixed. + * + * The archive also contains some Spark configuration. Namely, it saves the contents of + * SparkConf in a file to be loaded by the AM process. */ - private def createConfArchive(): Option[File] = { + private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => @@ -382,28 +451,32 @@ private[spark] class Client( } } - if (!hadoopConfFiles.isEmpty) { - val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", - new File(Utils.getLocalDir(sparkConf))) - - val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) - try { - hadoopConfStream.setLevel(0) - hadoopConfFiles.foreach { case (name, file) => - if (file.canRead()) { - hadoopConfStream.putNextEntry(new ZipEntry(name)) - Files.copy(file, hadoopConfStream) - hadoopConfStream.closeEntry() - } + val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + val confStream = new ZipOutputStream(new FileOutputStream(confArchive)) + + try { + confStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + if (file.canRead()) { + confStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, confStream) + confStream.closeEntry() } - } finally { - hadoopConfStream.close() } - Some(hadoopConfArchive) - } else { - None + // Save Spark configuration to a file in the archive. + val props = new Properties() + sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) + val writer = new OutputStreamWriter(confStream, UTF_8) + props.store(writer, "Spark configuration.") + writer.flush() + confStream.closeEntry() + } finally { + confStream.close() } + confArchive } /** @@ -431,7 +504,9 @@ private[spark] class Client( /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + private def setupLaunchEnv( + stagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") @@ -449,9 +524,6 @@ private[spark] class Client( val renewalInterval = getTokenRenewalInterval(stagingDirPath) sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) } - // Set the environment variables to be passed on to the executors. 
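
`createConfArchive()` now serves two purposes: it still zips up the Hadoop/YARN config files, and it additionally stores the SparkConf as a properties file so the AM can be started with `--properties-file` instead of receiving every setting as a `-D` java option. Below is a self-contained sketch of the write-then-read round trip with invented file and entry names; it only demonstrates the `Properties`-inside-a-zip mechanism, not Spark's exact archive layout.

```scala
import java.io.{FileOutputStream, InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Properties
import java.util.zip.{ZipEntry, ZipFile, ZipOutputStream}

object ConfArchiveSketch {
  def main(args: Array[String]): Unit = {
    val archive = java.io.File.createTempFile("conf-sketch", ".zip")
    val entryName = "spark.properties"            // assumed name, not Spark's

    // Write: one zip entry containing key=value pairs.
    val out = new ZipOutputStream(new FileOutputStream(archive))
    try {
      val props = new Properties()
      props.setProperty("spark.app.name", "sketch")
      out.putNextEntry(new ZipEntry(entryName))
      val writer = new OutputStreamWriter(out, UTF_8)
      props.store(writer, "sketch configuration")
      writer.flush()                              // keep the zip stream open for more entries
      out.closeEntry()
    } finally {
      out.close()
    }

    // Read it back, the way a launched process could after YARN unpacks the archive.
    val zip = new ZipFile(archive)
    try {
      val loaded = new Properties()
      loaded.load(new InputStreamReader(zip.getInputStream(zip.getEntry(entryName)), UTF_8))
      println(loaded.getProperty("spark.app.name"))   // sketch
    } finally {
      zip.close()
    }
  }
}
```
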
- distCacheMgr.setDistFilesEnv(env) - distCacheMgr.setDistArchivesEnv(env) // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -468,15 +540,32 @@ private[spark] class Client( env("SPARK_YARN_USER_ENV") = userEnvs } - // if spark.submit.pyArchives is in sparkConf, append pyArchives to PYTHONPATH - // that can be passed on to the ApplicationMaster and the executors. - if (sparkConf.contains("spark.submit.pyArchives")) { - var pythonPath = sparkConf.get("spark.submit.pyArchives") - if (env.contains("PYTHONPATH")) { - pythonPath = Seq(env.get("PYTHONPATH"), pythonPath).mkString(File.pathSeparator) + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH + // of the container processes too. Add all non-.py files directly to PYTHONPATH. + // + // NOTE: the code currently does not handle .py files defined with a "local:" scheme. + val pythonPath = new ListBuffer[String]() + val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + if (pyFiles.nonEmpty) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_PYTHON_DIR) + } + (pySparkArchives ++ pyArchives).foreach { path => + val uri = new URI(path) + if (uri.getScheme != LOCAL_SCHEME) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + new Path(path).getName()) + } else { + pythonPath += uri.getPath() } - env("PYTHONPATH") = pythonPath - sparkConf.setExecutorEnv("PYTHONPATH", pythonPath) + } + + // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. + if (pythonPath.nonEmpty) { + val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + .mkString(YarnSparkHadoopUtil.getClassPathSeparator) + env("PYTHONPATH") = pythonPathStr + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to @@ -526,8 +615,19 @@ private[spark] class Client( logInfo("Setting up container launch context for our AM") val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(appStagingDir) + val pySparkArchives = + if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) { + findPySparkArchives() + } else { + Nil + } + val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives) + val localResources = prepareLocalResources(appStagingDir, pySparkArchives) + + // Set the environment variables to be passed on to the executors. + distCacheMgr.setDistFilesEnv(launchEnv) + distCacheMgr.setDistArchivesEnv(launchEnv) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) amContainer.setEnvironment(launchEnv) @@ -567,13 +667,6 @@ private[spark] class Client( javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // Forward the Spark configuration to the application master / executors. - // TODO: it might be nicer to pass these as an internal environment variable rather than - // as Java options, due to complications with string parsing of nested quotes. 
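
The PYTHONPATH handling above splits `--py-files` into plain `.py` files (which all land in one localized subdirectory that must be added to the path exactly once) and archives (added individually: by localized name for distributed files, by literal path for `local:` URIs), then merges the result with any PYTHONPATH inherited from the environment. The sketch below reproduces that partition-and-join step with placeholder names for the localized directory, the expanded working directory, and the separator.

```scala
import java.io.File
import java.net.URI

object PythonPathSketch {
  val LocalizedPythonDir = "__pyfiles__"     // assumed name of the .py subdirectory
  val Pwd = "{{PWD}}"                        // stand-in for the container's expanded PWD

  def buildPythonPath(
      pyFiles: Seq[String],
      archives: Seq[String],
      inherited: Option[String]): String = {
    val entries = scala.collection.mutable.ListBuffer[String]()

    // All plain .py files share one localized directory, so it is added once.
    val (plainPy, pyArchives) = pyFiles.partition(_.endsWith(".py"))
    if (plainPy.nonEmpty) entries += s"$Pwd${File.separator}$LocalizedPythonDir"

    // Archives are added one by one: localized name for distributed files,
    // the literal path for "local:" URIs that already exist on every node.
    (archives ++ pyArchives).foreach { path =>
      val uri = new URI(path)
      if (uri.getScheme == "local") entries += uri.getPath
      else entries += s"$Pwd${File.separator}${new File(uri.getPath).getName}"
    }

    (inherited.toSeq ++ entries).mkString(File.pathSeparator)
  }

  def main(args: Array[String]): Unit = {
    println(buildPythonPath(
      pyFiles = Seq("deps.py", "extra.zip"),
      archives = Seq("local:/opt/py4j.zip", "hdfs:///user/me/pyspark.zip"),
      inherited = sys.env.get("PYTHONPATH")))
  }
}
```
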
- for ((k, v) <- sparkConf.getAll) { - javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") - } - // Include driver-specific java options if we are launching a driver if (isClusterMode) { val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") @@ -584,7 +677,7 @@ private[spark] class Client( val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") @@ -606,7 +699,7 @@ private[spark] class Client( } sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths))) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -626,14 +719,8 @@ private[spark] class Client( Nil } val primaryPyFile = - if (args.primaryPyFile != null) { - Seq("--primary-py-file", args.primaryPyFile) - } else { - Nil - } - val pyFiles = - if (args.pyFiles != null) { - Seq("--py-files", args.pyFiles) + if (isClusterMode && args.primaryPyFile != null) { + Seq("--primary-py-file", new Path(args.primaryPyFile).getName()) } else { Nil } @@ -645,13 +732,10 @@ private[spark] class Client( } val amClass = if (isClusterMode) { - Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName + Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { - Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName + Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } - if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs - } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs } @@ -659,11 +743,13 @@ private[spark] class Client( Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString) + "--num-executors ", args.numExecutors.toString, + "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) // Command for the ApplicationMaster val commands = prefixEnv ++ Seq( @@ -742,6 +828,9 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + case NonFatal(e) => + logError(s"Failed to contact YARN for application $appId.", e) + return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) } val state = report.getYarnApplicationState @@ -760,6 +849,7 @@ private[spark] class Client( if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + cleanupStagingDir(appId) return (state, report.getFinalApplicationStatus) } @@ -827,12 +917,28 @@ private[spark] class Client( } } } + + 
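
The `getClusterPath(sparkConf, ...)` wrapping added around the library-path values above (and around class-path entries further down) rewrites paths that are only valid on the gateway machine into their cluster-side equivalents, driven by `spark.yarn.config.gatewayPath` and `spark.yarn.config.replacementPath`; if either is unset, the path passes through untouched. The following is a minimal usage sketch of that substitution; the example paths and the `{{HADOOP_COMMON_HOME}}`-style replacement value are illustrative.

```scala
import org.apache.spark.SparkConf

object ClusterPathSketch {
  /** Replace the gateway-only prefix with its cluster-side equivalent, if both are configured. */
  def toClusterPath(conf: SparkConf, path: String): String = {
    val gatewayPath = conf.get("spark.yarn.config.gatewayPath", null)
    val replacement = conf.get("spark.yarn.config.replacementPath", null)
    if (gatewayPath != null && replacement != null) path.replace(gatewayPath, replacement)
    else path
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.yarn.config.gatewayPath", "/opt/gateway/libs")
      .set("spark.yarn.config.replacementPath", "{{HADOOP_COMMON_HOME}}/libs")

    // A driver-side library path is rewritten before being sent to the NodeManagers.
    println(toClusterPath(conf, "/opt/gateway/libs/native:/usr/lib"))
    // => {{HADOOP_COMMON_HOME}}/libs/native:/usr/lib
  }
}
```
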
private def findPySparkArchives(): Seq[String] = { + sys.env.get("PYSPARK_ARCHIVES_PATH") + .map(_.split(",").toSeq) + .getOrElse { + val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator) + val pyArchivesFile = new File(pyLibPath, "pyspark.zip") + require(pyArchivesFile.exists(), + "pyspark.zip not found; cannot run pyspark application in YARN mode.") + val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + require(py4jFile.exists(), + "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) + } + } + } object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { - println("WARNING: This client is deprecated and will be removed in a " + + logWarning("WARNING: This client is deprecated and will be removed in a " + "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") } @@ -877,8 +983,14 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" - // Subdirectory where the user's hadoop config files will be placed. - val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + // Subdirectory where the user's Spark and Hadoop config files will be placed. + val LOCALIZED_CONF_DIR = "__spark_conf__" + + // Name of the file in the conf archive containing Spark configuration. + val SPARK_CONF_FILE = "__spark_conf__.properties" + + // Subdirectory where the user's python files (not archives) will be placed. + val LOCALIZED_PYTHON_DIR = "__pyfiles__" /** * Find the user-defined Spark jar if configured, or return the jar containing this @@ -995,15 +1107,15 @@ object Client extends Logging { env: HashMap[String, String], isAM: Boolean, extraClassPath: Option[String] = None): Unit = { - extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env - ) + extraClassPath.foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) if (isAM) { addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + - LOCALIZED_HADOOP_CONF_DIR, env) + LOCALIZED_CONF_DIR, env) } if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { @@ -1014,12 +1126,14 @@ object Client extends Logging { getUserClasspath(sparkConf) } userClassPath.foreach { x => - addFileToClasspath(x, null, env) + addFileToClasspath(sparkConf, x, null, env) } } - addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) - sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) + sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } } /** @@ -1048,16 +1162,18 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * + * @param conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables.
*/ private def addFileToClasspath( + conf: SparkConf, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { if (uri != null && uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) + addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) @@ -1071,6 +1187,29 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) + /** + * Returns the path to be sent to the NM for a path that is valid on the gateway. + * + * This method uses two configuration values: + * + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. + * + * If either config is not available, the input path is returned. + */ + def getClusterPath(conf: SparkConf, path: String): String = { + val localPath = conf.get("spark.yarn.config.gatewayPath", null) + val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + if (localPath != null && clusterPath != null) { + path.replace(localPath, clusterPath) + } else { + path + } + } + /** * Obtains token for the Hive metastore and adds them to the credentials. */ @@ -1120,9 +1259,9 @@ object Client extends Logging { logDebug("HiveMetaStore configured in localmode") } } catch { - case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } - case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } - case e:Exception => { logError("Unexpected Exception " + e) + case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e: Exception => { logError("Unexpected Exception " + e) throw new RuntimeException("Unexpected exception", e) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 5653c9f14dc6d..20d63d40cf605 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -30,7 +30,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var archives: String = null var userJar: String = null var userClass: String = null - var pyFiles: String = null + var pyFiles: Seq[String] = Nil var primaryPyFile: String = null var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() @@ -46,7 +46,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var keytab: String = null def isClusterMode: Boolean = userClass != null - private var driverMemory: Int = 512 // MB + private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB private var driverCores: Int = 1 private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead" private val amMemKey = "spark.yarn.am.memory" @@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) numExecutors = 
initialNumExecutors } + principal = Option(principal) + .orElse(sparkConf.getOption("spark.yarn.principal")) + .orNull + keytab = Option(keytab) + .orElse(sparkConf.getOption("spark.yarn.keytab")) + .orNull } /** @@ -117,6 +123,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new SparkException("Executor cores must not be less than " + "spark.task.cpus.") } + // scalastyle:off println if (isClusterMode) { for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { if (sparkConf.contains(key)) { @@ -138,11 +145,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .map(_.toInt) .foreach { cores => amCores = cores } } + // scalastyle:on println } private def parseArgs(inputArgs: List[String]): Unit = { var args = inputArgs + // scalastyle:off println while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => @@ -222,7 +231,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) args = tail case ("--py-files") :: value :: tail => - pyFiles = value + pyFiles = value.split(",") args = tail case ("--files") :: value :: tail => @@ -247,6 +256,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + // scalastyle:on println if (primaryPyFile != null && primaryRFile != null) { throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + @@ -256,8 +266,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB message + - """ + s""" |Usage: org.apache.spark.deploy.yarn.Client [options] |Options: | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster @@ -269,7 +280,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores per executor (Default: 1). - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb) | --driver-cores NUM Number of cores used by the driver (Default: 1). | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) | --name NAME The name of your application (Default: Spark) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index c592ecfdfce06..3d3a966960e9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging { * Add a resource to the list of distributed cache resources. This list can * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. 
- * Adds the LocalResource to the localResources HashMap passed in and saves + * Adds the LocalResource to the localResources HashMap passed in and saves * the stats of the resources to they can be sent to the executors and verified. * * @param fs FileSystem * @param conf Configuration * @param destPath path to the resource * @param localResources localResource hashMap to insert the resource into - * @param resourceType LocalResourceType + * @param resourceType LocalResourceType * @param link link presented in the distributed cache to the destination - * @param statCache cache to store the file/directory stats + * @param statCache cache to store the file/directory stats * @param appMasterOnly Whether to only add the resource to the app master */ def addResource( fs: FileSystem, conf: Configuration, - destPath: Path, + destPath: Path, localResources: HashMap[String, LocalResource], resourceType: LocalResourceType, link: String, @@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging { amJarRsrc.setSize(destStatus.getLen()) if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { - distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } else { - distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } } @@ -95,13 +95,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -112,13 +112,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = - sizes.reduceLeft[String] { 
(acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -160,7 +160,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def ancestorsHaveExecutePermissions( fs: FileSystem, path: Path, - statCache: Map[URI, FileStatus]): Boolean = { + statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { // the subdirs in the path should have execute permissions for others @@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { val stat = statCache.get(uri) match { case Some(existstat) => existstat - case None => + case None => val newStat = fs.getFileStatus(new Path(uri)) statCache.put(uri, newStat) newStat diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala index 229c2c4d5eb36..94feb6393fd69 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -35,6 +35,9 @@ private[spark] class ExecutorDelegationTokenUpdater( @volatile private var lastCredentialsFileSuffix = 0 private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val freshHadoopConf = + SparkHadoopUtil.get.getConfBypassingFSCache( + hadoopConf, new Path(credentialsFile).toUri.getScheme) private val delegationTokenRenewer = Executors.newSingleThreadScheduledExecutor( @@ -49,7 +52,7 @@ private[spark] class ExecutorDelegationTokenUpdater( def updateCredentialsIfRequired(): Unit = { try { val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) SparkHadoopUtil.get.listFilesSorted( remoteFs, credentialsFilePath.getParent, credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 9d04d241dae9e..78e27fb7f3337 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -146,7 +146,7 @@ class ExecutorRunnable( javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } javaOpts += "-Djava.io.tmpdir=" + @@ -195,7 +195,7 @@ class ExecutorRunnable( val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = if (new File(uri.getPath()).isAbsolute()) { - uri.getPath() + Client.getClusterPath(sparkConf, uri.getPath()) } else { Client.buildPath(Environment.PWD.$(), uri.getPath()) } @@ -303,8 +303,8 @@ class ExecutorRunnable( val address = container.getNodeHttpAddress val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" - env("SPARK_LOG_URL_STDERR") = 
s"$baseUrl/stderr?start=0" - env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=-4096" + env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" } System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 8a08f561a2df2..940873fbd046c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -34,10 +34,8 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.AkkaUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -53,6 +51,7 @@ import org.apache.spark.util.AkkaUtils * synchronized. */ private[yarn] class YarnAllocator( + driverUrl: String, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -107,13 +106,6 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - private val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(securityMgr.akkaSSLOptions.enabled), - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) @@ -154,11 +146,16 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. + * + * @return Whether the new requested total is different than the old value. */ - def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { + def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal + true + } else { + false } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b134751366522..7f533ee55e8bb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,6 +55,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * @param uiHistoryAddress Address of the application on the History Server. 
*/ def register( + driverUrl: String, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -72,7 +73,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** @@ -89,9 +90,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId = { - val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - containerId.getApplicationAttemptId() + YarnSparkHadoopUtil.get.getContainerId.getApplicationAttemptId() } /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index ba91872107d0c..68d01c17ef720 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -33,7 +33,8 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} +import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} +import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -136,12 +137,16 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { tokenRenewer.foreach(_.stop()) } + private[spark] def getContainerId: ContainerId = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + ConverterUtils.toContainerId(containerIdString) + } } object YarnSparkHadoopUtil { - // Additional memory overhead + // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering - // the common cases. Memory overhead tends to grow with container size. + // the common cases. Memory overhead tends to grow with container size. val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 99c05329b4d73..3a0b9443d2d7b 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -41,7 +41,6 @@ private[spark] class YarnClientSchedulerBackend( * This waits until the application is running. 
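The attempt-id lookup in `YarnRMClient.getAttemptId` above is now routed through the new `YarnSparkHadoopUtil.getContainerId` helper. As a hedged sketch of what that derivation amounts to (wrapped in a standalone object purely for illustration): YARN exports the id of the current container in the `CONTAINER_ID` environment variable, and the application attempt id is embedded in that container id.

```
import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ContainerId}
import org.apache.hadoop.yarn.util.ConverterUtils

object ContainerIdSketch {
  // Only meaningful inside a YARN container, where the NM sets CONTAINER_ID.
  def currentAttemptId(): ApplicationAttemptId = {
    val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
    val containerId: ContainerId = ConverterUtils.toContainerId(containerIdString)
    containerId.getApplicationAttemptId()
  }
}
```

Centralizing this in one helper lets both `YarnRMClient` and the new driver-log lookup further below reuse the same container id.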
*/ override def start() { - super.start() val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort @@ -56,6 +55,12 @@ private[spark] class YarnClientSchedulerBackend( totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.submitApplication() + + // SPARK-8687: Ensure all necessary properties have already been set before + // we initialize our driver scheduler backend, which serves these properties + // to the executors + super.start() + waitForApplication() monitorThread = asyncMonitorApplication() monitorThread.start() @@ -76,7 +81,8 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue") + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--py-files", null, "spark.submit.pyFiles") ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( @@ -86,7 +92,7 @@ private[spark] class YarnClientSchedulerBackend( optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - } else if (System.getenv(envVar) != null) { + } else if (envVar != null && System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.") @@ -147,7 +153,9 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - monitorThread.interrupt() + if (monitorThread != null) { + monitorThread.interrupt() + } super.stop() client.stop() logInfo("Stopped") diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index aeb218a575455..33f580aaebdc0 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -17,10 +17,19 @@ package org.apache.spark.scheduler.cluster +import java.net.NetworkInterface + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.yarn.api.records.NodeState +import org.apache.hadoop.yarn.client.api.YarnClient +import org.apache.hadoop.yarn.conf.YarnConfiguration + import org.apache.spark.SparkContext +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.IntParam +import org.apache.spark.util.{IntParam, Utils} private[spark] class YarnClusterSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -53,4 +62,71 @@ private[spark] class YarnClusterSchedulerBackend( logError("Application attempt ID is not set.") super.applicationAttemptId } + + override def getDriverLogUrls: Option[Map[String, String]] = { + var yarnClientOpt: Option[YarnClient] = None + var driverLogs: Option[Map[String, String]] = None + try { + val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) + val containerId = YarnSparkHadoopUtil.get.getContainerId 
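Both log-link changes in this patch (the executor links in `ExecutorRunnable` above and the AM links assembled in `getDriverLogUrls` here) use the same NodeManager URL shape, where `start=-4096` asks the NM web UI for roughly the last 4 KB of each file. A hedged sketch of that assembly; the scheme, address, container id and user are placeholders for the values the patch obtains from YARN:

```
object LogUrlSketch {
  // httpScheme:    "http://" or "https://", chosen from yarn.http.policy
  // nmHttpAddress: e.g. "nm-host:8042", taken from the node report / container
  // containerId:   from YarnSparkHadoopUtil.getContainerId (or the executor's container)
  // user:          from Utils.getCurrentUserName()
  def containerLogUrls(
      httpScheme: String,
      nmHttpAddress: String,
      containerId: String,
      user: String): Map[String, String] = {
    val baseUrl = s"$httpScheme$nmHttpAddress/node/containerlogs/$containerId/$user"
    Map(
      "stderr" -> s"$baseUrl/stderr?start=-4096",
      "stdout" -> s"$baseUrl/stdout?start=-4096")
  }
}
```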
+ yarnClientOpt = Some(YarnClient.createYarnClient()) + yarnClientOpt.foreach { yarnClient => + yarnClient.init(yarnConf) + yarnClient.start() + + // For newer versions of YARN, we can find the HTTP address for a given node by getting a + // container report for a given container. But container reports came only in Hadoop 2.4, + // so we basically have to get the node reports for all nodes and find the one which runs + // this container. For that we have to compare the node's host against the current host. + // Since the host can have multiple addresses, we need to compare against all of them to + // find out if one matches. + + // Get all the addresses of this node. + val addresses = + NetworkInterface.getNetworkInterfaces.asScala + .flatMap(_.getInetAddresses.asScala) + .toSeq + + // Find a node report that matches one of the addresses + val nodeReport = + yarnClient.getNodeReports(NodeState.RUNNING).asScala.find { x => + val host = x.getNodeId.getHost + addresses.exists { address => + address.getHostAddress == host || + address.getHostName == host || + address.getCanonicalHostName == host + } + } + + // Now that we have found the report for the Node Manager that the AM is running on, we + // can get the base HTTP address for the Node manager from the report. + // The format used for the logs for each container is well-known and can be constructed + // using the NM's HTTP address and the container ID. + // The NM may be running several containers, but we can build the URL for the AM using + // the AM's container ID, which we already know. + nodeReport.foreach { report => + val httpAddress = report.getHttpAddress + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val user = Utils.getCurrentUserName() + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" + logDebug(s"Base URL for logs: $baseUrl") + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) + } + } + } catch { + case e: Exception => + logInfo("Node Report API is not available in the version of YARN being used, so AM" + + " logs link will not appear in application UI", e) + } finally { + yarnClientOpt.foreach(_.close()) + } + driverLogs + } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 80b57d1355a3a..804dfecde7867 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.mockito.Mockito.when @@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} import scala.collection.mutable.HashMap import scala.collection.mutable.Map +import org.apache.spark.SparkFunSuite -class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + +class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { - override def getVisibility(conf: Configuration, uri: URI, 
statCache: Map[URI, FileStatus]): + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { LocalResourceVisibility.PRIVATE } } - + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] @@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] val uri = new URI("/tmp/testing") - val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) @@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) // add another one and verify both there and order correct - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", statCache, false) val resource2 = localResources("link2") assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val env2 = new HashMap[String, String]() distMgr.setDistFilesEnv(env2) val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") @@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) intercept[Exception] { - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache, false) } assert(localResources.get("link") === None) @@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val 
realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, true) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 508819e242a26..837f8d3fa55a7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { override def beforeAll(): Unit = { System.setProperty("SPARK_YARN_MODE", "true") @@ -113,7 +113,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { Environment.PWD.$() } cp should contain(pwdVar) - cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -129,7 +129,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { val tempDir = Utils.createTempDir() try { - client.prepareLocalResources(tempDir.getAbsolutePath()) + client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER)) // The non-local path should be propagated by name only, since it will end up in the app's @@ -151,6 +151,25 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { } } + test("Cluster path translation") { + val conf = new Configuration() + val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") + .set("spark.yarn.config.gatewayPath", 
"/localPath") + .set("spark.yarn.config.replacementPath", "/remotePath") + + Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + "/remotePath/1:/remotePath/2") + + val env = new MutableHashMap[String, String]() + Client.populateClasspath(null, conf, sparkConf, env, false, + extraClassPath = Some("/localPath/my1.jar")) + val cp = classpath(env) + cp should contain ("/remotePath/spark.jar") + cp should contain ("/remotePath/my1.jar") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = @@ -203,7 +222,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 455f1019d86dd..7509000771d94 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.spark.SecurityManager +import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} class MockResolver extends DNSToSwitchMapping { @@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping { def reloadCachedMappings(names: JList[String]) {} } -class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new Configuration() conf.setClass( CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, @@ -90,6 +90,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach "--jar", "somejar.jar", "--class", "SomeClass") new YarnAllocator( + "not used", conf, sparkConf, rmClient, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index d3c606e0ed998..547863d9a0739 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.net.URL import java.util.Properties import java.util.concurrent.TimeUnit @@ -29,11 +30,12 @@ import com.google.common.io.ByteStreams import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, 
TestUtils} +import org.apache.spark._ import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, + SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -41,7 +43,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { +class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the YARN containers, so that their output is collected // by YARN instead of trying to overwrite unit-tests.log. @@ -54,6 +56,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit """.stripMargin private val TEST_PYFILE = """ + |import mod1, mod2 |import sys |from operator import add | @@ -65,7 +68,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit | sc = SparkContext(conf=SparkConf()) | status = open(sys.argv[1],'w') | result = "failure" - | rdd = sc.parallelize(range(10)) + | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func()) | cnt = rdd.count() | if cnt == 10: | result = "success" @@ -74,6 +77,11 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit | sc.stop() """.stripMargin + private val TEST_PYMODULE = """ + |def func(): + | return 42 + """.stripMargin + private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ @@ -122,7 +130,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) assert(hadoopConfDir.mkdir()) File.createTempFile("token", ".txt", hadoopConfDir) } @@ -149,26 +157,12 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } } - // Enable this once fix SPARK-6700 - test("run Python application in yarn-cluster mode") { - val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, UTF_8) - val pyFile = new File(tempDir, "test2.py") - Files.write(TEST_PYFILE, pyFile, UTF_8) - var result = File.createTempFile("result", null, tempDir) + test("run Python application in yarn-client mode") { + testPySpark(true) + } - // The sbt assembly does not include pyspark / py4j python dependencies, so we need to - // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala. 
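On the suite side, `--py-files` now carries two kinds of entries (a plain `.py` module and an archive containing a module), joined with commas by the `testPySpark` helper below, while `ClientArguments` earlier in this patch splits that value back into a `Seq[String]`. A small sketch of that round trip, with made-up paths:

```
// Paths are made up for illustration. The submit side joins the entries with
// commas (as testPySpark's `pyFiles` value does); ClientArguments splits them
// back into a Seq[String] so each entry can be distributed to the application
// and placed on its Python path.
val entries = Seq("/tmp/pyModules/mod1.py", "/tmp/mod2-archive.jar")
val pyFilesArg = entries.mkString(",")                 // the value after --py-files
val parsed: Seq[String] = pyFilesArg.split(",").toSeq  // what ClientArguments keeps
assert(parsed == entries)
```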
- val sparkHome = sys.props("spark.test.home") - val extraConf = Map( - "spark.executorEnv.SPARK_HOME" -> sparkHome, - "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) - - runSpark(false, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()), - appArgs = Seq(result.getAbsolutePath()), - extraConf = extraConf) - checkResult(result) + test("run Python application in yarn-cluster mode") { + testPySpark(false) } test("user class path first in client mode") { @@ -186,6 +180,33 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit checkResult(result) } + private def testPySpark(clientMode: Boolean): Unit = { + val primaryPyFile = new File(tempDir, "test.py") + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + + val moduleDir = + if (clientMode) { + // In client-mode, .py files added with --py-files are not visible in the driver. + // This is something that the launcher library would have to handle. + tempDir + } else { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } + val pyModule = new File(moduleDir, "mod1.py") + Files.write(TEST_PYMODULE, pyModule, UTF_8) + + val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) + val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") + val result = File.createTempFile("result", null, tempDir) + + runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFiles), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) + } + private def testUseClassPathFirst(clientMode: Boolean): Unit = { // Create a jar file that contains a different version of "test.resource". val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) @@ -290,10 +311,15 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit private[spark] class SaveExecutorInfo extends SparkListener { val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + var driverLogs: Option[collection.Map[String, String]] = None override def onExecutorAdded(executor: SparkListenerExecutorAdded) { addedExecutorInfos(executor.executorId) = executor.executorInfo } + + override def onApplicationStart(appStart: SparkListenerApplicationStart): Unit = { + driverLogs = appStart.driverLogs + } } private object YarnClusterDriver extends Logging with Matchers { @@ -302,23 +328,26 @@ private object YarnClusterDriver extends Logging with Matchers { def main(args: Array[String]): Unit = { if (args.length != 1) { + // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClusterDriver [result file] """.stripMargin) + // scalastyle:on println System.exit(1) } val sc = new SparkContext(new SparkConf() .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) + val conf = sc.getConf val status = new File(args(0)) var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) data should be (Set(1, 2, 3, 4)) result = "success" } finally { @@ -335,6 +364,22 @@ private object YarnClusterDriver extends Logging with Matchers { executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) } + + // If we are running in yarn-cluster mode, verify that driver logs links and present and are + // in the expected 
format. + if (conf.get("spark.master") == "yarn-cluster") { + assert(listener.driverLogs.nonEmpty) + val driverLogs = listener.driverLogs.get + assert(driverLogs.size === 2) + assert(driverLogs.containsKey("stderr")) + assert(driverLogs.containsKey("stdout")) + val urlStr = driverLogs("stderr") + // Ensure that this is a valid URL, else this will throw an exception + new URL(urlStr) + val containerId = YarnSparkHadoopUtil.get.getContainerId + val user = Utils.getCurrentUserName() + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) + } } } @@ -343,12 +388,14 @@ private object YarnClasspathTest { def main(args: Array[String]): Unit = { if (args.length != 2) { + // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClasspathTest [driver result file] [executor result file] """.stripMargin) + // scalastyle:on println System.exit(1) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index e10b985c3c236..49bee0866dd43 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { val hasBash = try {