diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000..2b65f6fe3cc80
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.bat text eol=crlf
+*.cmd text eol=crlf
diff --git a/.rat-excludes b/.rat-excludes
index ae9745673c87d..20e3372464386 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -1,5 +1,6 @@
target
.gitignore
+.gitattributes
.project
.classpath
.mima-excludes
diff --git a/LICENSE b/LICENSE
index f1732fb47afc0..3c667bf45059a 100644
--- a/LICENSE
+++ b/LICENSE
@@ -754,7 +754,7 @@ SUCH DAMAGE.
========================================================================
-For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java):
+For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java):
========================================================================
Copyright (C) 2008 The Android Open Source Project
@@ -771,6 +771,25 @@ See the License for the specific language governing permissions and
limitations under the License.
+========================================================================
+For LimitedInputStream
+ (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java):
+========================================================================
+Copyright (C) 2007 The Guava Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/README.md b/README.md
index 9916ac7b1ae8e..8d57d50da96c9 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,8 @@ and Spark Streaming for stream processing.
## Online Documentation
You can find the latest Spark documentation, including a programming
-guide, on the [project web page](http://spark.apache.org/documentation.html).
+guide, on the [project web page](http://spark.apache.org/documentation.html)
+and [project wiki](https://cwiki.apache.org/confluence/display/SPARK).
This README file only contains basic setup instructions.
## Building Spark
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 11d4bea9361ab..31a01e4d8e1de 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -146,10 +146,6 @@
                   <exclude>com/google/common/base/Present*</exclude>
-              <relocation>
-                <pattern>org.apache.commons.math3</pattern>
-                <shadedPattern>org.spark-project.commons.math3</shadedPattern>
-              </relocation>
@@ -201,12 +197,6 @@
           <artifactId>spark-hive_${scala.binary.version}</artifactId>
           <version>${project.version}</version>
-        </dependency>
-      </dependencies>
-    </profile>
-    <profile>
-      <id>hive-0.12.0</id>
-      <dependencies>
-        <dependency>
-          <groupId>org.apache.spark</groupId>
-          <artifactId>spark-hive-thriftserver_${scala.binary.version}</artifactId>
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index 3cd0579aea8d3..a4c099fb45b14 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -1,117 +1,117 @@
-@echo off
-
-rem
-rem Licensed to the Apache Software Foundation (ASF) under one or more
-rem contributor license agreements. See the NOTICE file distributed with
-rem this work for additional information regarding copyright ownership.
-rem The ASF licenses this file to You under the Apache License, Version 2.0
-rem (the "License"); you may not use this file except in compliance with
-rem the License. You may obtain a copy of the License at
-rem
-rem http://www.apache.org/licenses/LICENSE-2.0
-rem
-rem Unless required by applicable law or agreed to in writing, software
-rem distributed under the License is distributed on an "AS IS" BASIS,
-rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-rem See the License for the specific language governing permissions and
-rem limitations under the License.
-rem
-
-rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
-rem script and the ExecutorRunner in standalone cluster mode.
-
-rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
-rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
-rem need to set it here because we use !datanucleus_jars! below.
-if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
-setlocal enabledelayedexpansion
-:skip_delayed_expansion
-
-set SCALA_VERSION=2.10
-
-rem Figure out where the Spark framework is installed
-set FWDIR=%~dp0..\
-
-rem Load environment variables from conf\spark-env.cmd, if it exists
-if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
-
-rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
-
-if not "x%SPARK_CONF_DIR%"=="x" (
- set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
-) else (
- set CLASSPATH=%CLASSPATH%;%FWDIR%conf
-)
-
-if exist "%FWDIR%RELEASE" (
- for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-) else (
- for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-)
-
-set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
-
-rem When Hive support is needed, Datanucleus jars must be included on the classpath.
-rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
-rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
-rem built with Hive, so look for them there.
-if exist "%FWDIR%RELEASE" (
- set datanucleus_dir=%FWDIR%lib
-) else (
- set datanucleus_dir=%FWDIR%lib_managed\jars
-)
-set "datanucleus_jars="
-for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
- set datanucleus_jars=!datanucleus_jars!;%%d
-)
-set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
-
-set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
-
-set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
-
-if "x%SPARK_TESTING%"=="x1" (
- rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
- rem so that local compilation takes precedence over assembled jar
- set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
-)
-
-rem Add hadoop conf dir - else FileSystem.*, etc fail
-rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
-rem the configurtion files.
-if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
- set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
-:no_hadoop_conf_dir
-
-if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
- set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
-:no_yarn_conf_dir
-
-rem A bit of a hack to allow calling this script within run2.cmd without seeing output
-if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
-
-echo %CLASSPATH%
-
-:exit
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+rem script and the ExecutorRunner in standalone cluster mode.
+
+rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
+rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
+rem need to set it here because we use !datanucleus_jars! below.
+if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
+setlocal enabledelayedexpansion
+:skip_delayed_expansion
+
+set SCALA_VERSION=2.10
+
+rem Figure out where the Spark framework is installed
+set FWDIR=%~dp0..\
+
+rem Load environment variables from conf\spark-env.cmd, if it exists
+if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+
+rem Build up classpath
+set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
+
+if not "x%SPARK_CONF_DIR%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
+) else (
+ set CLASSPATH=%CLASSPATH%;%FWDIR%conf
+)
+
+if exist "%FWDIR%RELEASE" (
+ for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+) else (
+ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+)
+
+set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
+
+rem When Hive support is needed, Datanucleus jars must be included on the classpath.
+rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
+rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
+rem built with Hive, so look for them there.
+if exist "%FWDIR%RELEASE" (
+ set datanucleus_dir=%FWDIR%lib
+) else (
+ set datanucleus_dir=%FWDIR%lib_managed\jars
+)
+set "datanucleus_jars="
+for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
+ set datanucleus_jars=!datanucleus_jars!;%%d
+)
+set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
+
+set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
+
+set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
+
+if "x%SPARK_TESTING%"=="x1" (
+ rem Add test classes to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
+ rem so that local compilation takes precedence over assembled jar
+ set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
+)
+
+rem Add hadoop conf dir - else FileSystem.*, etc fail
+rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
+rem the configuration files.
+if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
+ set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
+:no_hadoop_conf_dir
+
+if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
+ set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
+:no_yarn_conf_dir
+
+rem A bit of a hack to allow calling this script within run2.cmd without seeing output
+if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
+
+echo %CLASSPATH%
+
+:exit
diff --git a/core/pom.xml b/core/pom.xml
index 8020a2daf81ec..41296e0eca330 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -46,7 +46,12 @@
       <groupId>org.apache.spark</groupId>
-      <artifactId>network</artifactId>
+      <artifactId>spark-network-common_2.10</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-network-shuffle_2.10</artifactId>
       <version>${project.version}</version>
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
new file mode 100644
index 0000000000000..badd85ed48c82
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Register functions to show/hide columns based on checkboxes. These need
+ * to be registered after the page loads. */
+$(function() {
+ $("span.expand-additional-metrics").click(function(){
+ // Expand the list of additional metrics.
+ var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
+ $(additionalMetricsDiv).toggleClass('collapsed');
+
+ // Switch the class of the arrow from open to closed.
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
+
+ // If clicking caused the metrics to expand, automatically check all options for additional
+ // metrics (don't trigger a click when collapsing metrics, because it leads to weird
+ // toggling behavior).
+ if (!$(additionalMetricsDiv).hasClass('collapsed')) {
+ $(this).parent().find('input:checkbox:not(:checked)').trigger('click');
+ }
+ });
+
+ $("input:checkbox:not(:checked)").each(function() {
+ var column = "table ." + $(this).attr("name");
+ $(column).hide();
+ });
+ // Stripe table rows after rows have been hidden to ensure correct striping.
+ stripeTables();
+
+ $("input:checkbox").click(function() {
+ var column = "table ." + $(this).attr("name");
+ $(column).toggle();
+ stripeTables();
+ });
+
+ // Trigger a click on the checkbox if a user clicks the label next to it.
+ $("span.additional-metric-title").click(function() {
+ $(this).parent().find('input:checkbox').trigger('click');
+ });
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js
new file mode 100644
index 0000000000000..6bb03015abb51
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/table.js
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Adds background colors to stripe table rows. This is necessary (instead of using css or the
+ * table striping provided by bootstrap) to appropriately stripe tables with hidden rows. */
+function stripeTables() {
+ $("table.table-striped-custom").each(function() {
+ $(this).find("tr:not(:hidden)").each(function (index) {
+ if (index % 2 == 1) {
+ $(this).css("background-color", "#f9f9f9");
+ } else {
+ $(this).css("background-color", "#ffffff");
+ }
+ });
+ });
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 152bde5f6994f..db57712c83503 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -120,7 +120,51 @@ pre {
border: none;
}
+.stacktrace-details {
+ max-height: 300px;
+ overflow-y: auto;
+ margin: 0;
+ transition: max-height 0.5s ease-out, padding 0.5s ease-out;
+}
+
+.stacktrace-details.collapsed {
+ max-height: 0;
+ padding-top: 0;
+ padding-bottom: 0;
+ border: none;
+}
+
+span.expand-additional-metrics {
+ cursor: pointer;
+}
+
+span.additional-metric-title {
+ cursor: pointer;
+}
+
+.additional-metrics.collapsed {
+ display: none;
+}
+
.tooltip {
font-weight: normal;
}
+.arrow-open {
+ width: 0;
+ height: 0;
+ border-left: 5px solid transparent;
+ border-right: 5px solid transparent;
+ border-top: 5px solid black;
+ float: left;
+ margin-top: 6px;
+}
+
+.arrow-closed {
+ width: 0;
+ height: 0;
+ border-top: 5px solid transparent;
+ border-bottom: 5px solid transparent;
+ border-left: 5px solid black;
+ display: inline-block;
+}
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index b2cf022baf29f..ef93009a074e7 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -66,7 +66,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
// Lower and upper bounds on the number of executors. These are required.
private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1)
private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1)
- verifyBounds()
// How long there must be backlogged tasks for before an addition is triggered
private val schedulerBacklogTimeout = conf.getLong(
@@ -77,9 +76,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
"spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout)
// How long an executor must be idle for before it is removed
- private val removeThresholdSeconds = conf.getLong(
+ private val executorIdleTimeout = conf.getLong(
"spark.dynamicAllocation.executorIdleTimeout", 600)
+ // During testing, the methods to actually kill and add executors are mocked out
+ private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
+
+ validateSettings()
+
// Number of executors to add in the next round
private var numExecutorsToAdd = 1
@@ -103,17 +107,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
// Polling loop interval (ms)
private val intervalMillis: Long = 100
- // Whether we are testing this class. This should only be used internally.
- private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
-
// Clock used to schedule when executors should be added and removed
private var clock: Clock = new RealClock
/**
- * Verify that the lower and upper bounds on the number of executors are valid.
+ * Verify that the settings specified through the config are valid.
* If not, throw an appropriate exception.
*/
- private def verifyBounds(): Unit = {
+ private def validateSettings(): Unit = {
if (minNumExecutors < 0 || maxNumExecutors < 0) {
throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!")
}
@@ -124,6 +125,22 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " +
s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!")
}
+ if (schedulerBacklogTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!")
+ }
+ if (sustainedSchedulerBacklogTimeout <= 0) {
+ throw new SparkException(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!")
+ }
+ if (executorIdleTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!")
+ }
+ // Require external shuffle service for dynamic allocation
+ // Otherwise, we may lose shuffle files when killing executors
+ if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) {
+ throw new SparkException("Dynamic allocation of executors requires the external " +
+ "shuffle service. You may enable this through spark.shuffle.service.enabled.")
+ }
}
/**
@@ -254,7 +271,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
val removeRequestAcknowledged = testing || sc.killExecutor(executorId)
if (removeRequestAcknowledged) {
logInfo(s"Removing executor $executorId because it has been idle for " +
- s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})")
+ s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})")
executorsPendingToRemove.add(executorId)
true
} else {
@@ -329,8 +346,8 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
private def onExecutorIdle(executorId: String): Unit = synchronized {
if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
- s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)")
- removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000
+ s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
}
}
@@ -419,7 +436,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
val executorId = blockManagerAdded.blockManagerId.executorId
- if (executorId != "") {
+ if (executorId != SparkContext.DRIVER_IDENTIFIER) {
allocationManager.onExecutorAdded(executorId)
}
}
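
For reference, a minimal configuration that passes the new validateSettings() checks above. This is a sketch with illustrative values; only the property keys are taken from this diff:

    import org.apache.spark.SparkConf

    // Both bounds are required (validateSettings() rejects the -1 defaults),
    // every timeout must be > 0, and the external shuffle service must be
    // enabled unless spark.dynamicAllocation.testing is set.
    val conf = new SparkConf()
      .set("spark.dynamicAllocation.minExecutors", "1")
      .set("spark.dynamicAllocation.maxExecutors", "10")
      .set("spark.dynamicAllocation.schedulerBacklogTimeout", "5")
      .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "5")
      .set("spark.dynamicAllocation.executorIdleTimeout", "600")
      .set("spark.shuffle.service.enabled", "true")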
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4cb0bd4142435..7d96962c4acd7 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
+ logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
@@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
new ConcurrentHashMap[Int, Array[MapStatus]]
}
-private[spark] object MapOutputTracker {
+private[spark] object MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
@@ -381,6 +382,7 @@ private[spark] object MapOutputTracker {
statuses.map {
status =>
if (status == null) {
+ logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
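
The hunks above sit under a comment describing how map output statuses are serialized and then GZIP-compressed before being shipped to reduce tasks. A generic sketch of that serialize-then-compress pattern (not the exact Spark code):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}
    import java.util.zip.GZIPOutputStream

    // Serialize an object graph, compressing the bytes with GZIP on the way out.
    def serializeCompressed(obj: AnyRef): Array[Byte] = {
      val buffer = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(new GZIPOutputStream(buffer))
      try out.writeObject(obj) finally out.close()
      buffer.toByteArray
    }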
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7b2377e..dbff9d12b5ad7 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication}
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
/**
* Spark class responsible for security.
@@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Authenticator installed in the SecurityManager to how it does the authentication
* and in this case gets the user name and password from the request.
*
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferService uses Java NIO to asynchronously
* exchange messages. For this we use the Java SASL
* (Simple Authentication and Security Layer) API and again use DIGEST-MD5
* as the authentication mechanism. This means the shared secret is not passed
@@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous message passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
* and a Server, so for a particular connection is has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
@@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and waits for the response from the server and does the handshake before sending
* the real message.
*
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
* properly. For non-Yarn deployments, users can write a filter to go through a
@@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* can take place.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
@@ -337,4 +342,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
* @return the secret key as a String if authentication is enabled, otherwise returns null
*/
def getSecretKey(): String = secretKey
+
+ // Default SecurityManager only has a single secret key, so ignore appId.
+ override def getSaslUser(appId: String): String = getSaslUser()
+ override def getSecretKey(appId: String): String = getSecretKey()
}
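
SecurityManager now doubles as the SecretKeyHolder consulted by the new network/common SASL layer; it ignores appId because it manages a single per-process secret. For contrast, a hypothetical holder that does key secrets by application id (illustration only; the user name and map are made up):

    import org.apache.spark.network.sasl.SecretKeyHolder

    // Hypothetical: one secret per application, e.g. for a long-running
    // service that authenticates several Spark apps.
    class MultiAppSecretKeyHolder(secrets: Map[String, String]) extends SecretKeyHolder {
      override def getSaslUser(appId: String): String = "sparkSaslUser"
      override def getSecretKey(appId: String): String =
        secrets.getOrElse(appId, sys.error(s"no secret registered for app $appId"))
    }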
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index ad0a9017afead..4c6c86c7bad78 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
getAll.filter { case (k, _) => isAkkaConf(k) }
+ /**
+ * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+ * from the start in the Executor.
+ */
+ def getAppId: String = get("spark.app.id")
+
/** Does the configuration contain a given parameter? */
def contains(key: String): Boolean = settings.contains(key)
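
Per the scaladoc just added, spark.app.id only exists on the driver once the TaskScheduler has registered, so getAppId is safe after SparkContext construction. A sketch (local master chosen purely for illustration):

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(
      new SparkConf().setMaster("local[*]").setAppName("app-id-demo"))
    // The SparkContext constructor sets "spark.app.id" (see the SparkContext
    // hunk below), so this lookup no longer throws NoSuchElementException.
    val appId = sc.getConf.getAppId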
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 73668e83bbb1d..03ea672c813d1 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,9 +21,8 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
-import java.util.Arrays
+import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.generic.Growable
@@ -41,7 +40,8 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.input.WholeTextFileInputFormat
+import org.apache.spark.executor.TriggerThreadDump
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
import org.apache.spark.ui.SparkUI
import org.apache.spark.ui.jobs.JobProgressListener
-import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
+import org.apache.spark.util._
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
val applicationId: String = taskScheduler.applicationId()
conf.set("spark.app.id", applicationId)
+ env.blockManager.initialize(applicationId)
+
val metricsSystem = env.metricsSystem
// The metrics system for Driver need to be set spark.app.id to app ID.
@@ -361,6 +363,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ /**
+ * Called by the web UI to obtain executor thread dumps. This method may be expensive.
+ * Logs an error and returns None if we failed to obtain a thread dump, which could occur due
+ * to an executor being dead or unresponsive or due to network issues while sending the thread
+ * dump message back to the driver.
+ */
+ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
+ try {
+ if (executorId == SparkContext.DRIVER_IDENTIFIER) {
+ Some(Utils.getThreadDump())
+ } else {
+ val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
+ val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
+ Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
+ AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception getting thread dump from executor $executorId", e)
+ None
+ }
+ }
+
private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
@@ -533,6 +558,73 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
minPartitions).setName(path)
}
+
+ /**
+ * :: Experimental ::
+ *
+ * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
+ * (useful for binary data)
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @param minPartitions A suggestion value of the minimal splitting number for input data.
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
+ RDD[(String, PortableDataStream)] = {
+ val job = new NewHadoopJob(hadoopConfiguration)
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val updateConf = job.getConfiguration
+ new BinaryFileRDD(
+ this,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ updateConf,
+ minPartitions).setName(path)
+ }
+
+ /**
+ * :: Experimental ::
+ *
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory to the input data files
+ * @param recordLength The length at which to split the records
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
+ : RDD[Array[Byte]] = {
+ conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
+ val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
+ classOf[FixedLengthBinaryInputFormat],
+ classOf[LongWritable],
+ classOf[BytesWritable],
+ conf=conf)
+ val data = br.map { case (k, v) => v.getBytes }
+ data
+ }
+
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
* necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
@@ -1333,6 +1425,8 @@ object SparkContext extends Logging {
private[spark] val SPARK_UNKNOWN_USER = ""
+ private[spark] val DRIVER_IDENTIFIER = ""
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
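
A usage sketch for the two experimental input APIs added above; the paths and the 512-byte record length are placeholders:

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(
      new SparkConf().setMaster("local[*]").setAppName("binary-input-demo"))

    // Whole files as (path, PortableDataStream) pairs; the stream contents
    // are only materialized when read, here via toArray().
    val streams = sc.binaryFiles("hdfs://a-hdfs-path")
    val bytesPerFile = streams.map { case (path, stream) => (path, stream.toArray()) }

    // Fixed-length records (e.g. 512-byte frames) as Array[Byte] values.
    val records = sc.binaryRecords("hdfs://a-hdfs-path/frames.bin", 512)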
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 6a6dfda363974..e7454beddbfd0 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
-import org.apache.spark.network.netty.{NettyBlockTransferService}
+import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -156,7 +156,7 @@ object SparkEnv extends Logging {
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
val hostname = conf.get("spark.driver.host")
val port = conf.get("spark.driver.port").toInt
- create(conf, "", hostname, port, true, isLocal, listenerBus)
+ create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus)
}
/**
@@ -274,9 +274,9 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
val blockTransferService =
- conf.get("spark.shuffle.blockTransferService", "nio").toLowerCase match {
+ conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
case "netty" =>
- new NettyBlockTransferService(conf)
+ new NettyBlockTransferService(conf, securityManager)
case "nio" =>
new NioBlockTransferService(conf, securityManager)
}
@@ -285,8 +285,9 @@ object SparkEnv extends Logging {
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ // NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
+ serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager)
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
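
Note that the hunk above also flips the default of spark.shuffle.blockTransferService from nio to netty; the previous service remains one setting away (sketch):

    import org.apache.spark.SparkConf

    // Opt back into the NIO-based transfer service; "netty" is now the default.
    val conf = new SparkConf().set("spark.shuffle.blockTransferService", "nio")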
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 376e69cd997d5..40237596570de 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index a954fcc0c31fa..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Used to respond to server's counterpart, SaslServer with SASL tokens
- * represented as byte arrays.
- *
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
- null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslClientCallbackHandler(securityMgr))
-
- /**
- * Used to initiate SASL handshake with server.
- * @return response to challenge if needed
- */
- def firstToken(): Array[Byte] = {
- synchronized {
- val saslToken: Array[Byte] =
- if (saslClient != null && saslClient.hasInitialResponse()) {
- logDebug("has initial response")
- saslClient.evaluateChallenge(new Array[Byte](0))
- } else {
- new Array[Byte](0)
- }
- saslToken
- }
- }
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslClient != null) saslClient.isComplete() else false
- }
- }
-
- /**
- * Respond to server's SASL token.
- * @param saslTokenMessage contains server's SASL token
- * @return client's response SASL token
- */
- def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslClient might be using.
- */
- def dispose() {
- synchronized {
- if (saslClient != null) {
- try {
- saslClient.dispose()
- } catch {
- case e: SaslException => // ignored
- } finally {
- saslClient = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * that works with share secrets.
- */
- private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
- CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
- private val secretKey = securityMgr.getSecretKey()
- private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
-
- /**
- * Implementation used to respond to SASL request from the server.
- *
- * @param callbacks objects that indicate what credential information the
- * server's SaslServer requires from the client.
- */
- override def handle(callbacks: Array[Callback]) {
- logDebug("in the sasl client callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL client callback: setting username: " + userName)
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL client callback: setting userPassword")
- pc.setPassword(userPassword)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case cb: RealmChoiceCallback => {}
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index 7c2afb364661f..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Actual SASL work done by this object from javax.security.sasl.
- */
- private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
- SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslDigestCallbackHandler(securityMgr))
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslServer != null) saslServer.isComplete() else false
- }
- }
-
- /**
- * Used to respond to server SASL tokens.
- * @param token Server's SASL token
- * @return response to send back to the server.
- */
- def response(token: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslServer might be using.
- */
- def dispose() {
- synchronized {
- if (saslServer != null) {
- try {
- saslServer.dispose()
- } catch {
- case e: SaslException => // ignore
- } finally {
- saslServer = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * for SASL DIGEST-MD5 mechanism
- */
- private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
- extends CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
-
- override def handle(callbacks: Array[Callback]) {
- logDebug("In the sasl server callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL server callback: setting username")
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL server callback: setting userPassword")
- val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
- pc.setPassword(password)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case ac: AuthorizeCallback => {
- val authid = ac.getAuthenticationID()
- val authzid = ac.getAuthorizationID()
- if (authid.equals(authzid)) {
- logDebug("set auth to true")
- ac.setAuthorized(true)
- } else {
- logDebug("set auth to false")
- ac.setAuthorized(false)
- }
- if (ac.isAuthorized()) {
- logDebug("sasl server is authorized")
- ac.setAuthorizedID(authzid)
- }
- }
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
- }
- }
- }
-}
-
-private[spark] object SparkSaslServer {
-
- /**
- * This is passed as the server name when creating the sasl client/server.
- * This could be changed to be configurable in the future.
- */
- val SASL_DEFAULT_REALM = "default"
-
- /**
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- val DIGEST = "DIGEST-MD5"
-
- /**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
- */
- val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
- /**
- * Encode a byte[] identifier as a Base64-encoded string.
- *
- * @param identifier identifier to encode
- * @return Base64-encoded string
- */
- def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), UTF_8)
- }
-
- /**
- * Encode a password as a base64-encoded char[] array.
- * @param password as a byte array.
- * @return password as a char array.
- */
- def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), UTF_8).toCharArray()
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8f0c5e78416c2..af5fd8e0ac00c 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -69,11 +69,13 @@ case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleId: Int,
mapId: Int,
- reduceId: Int)
+ reduceId: Int,
+ message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
- s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)"
+ s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
+ s"message=\n$message\n)"
}
}
@@ -81,15 +83,48 @@ case class FetchFailed(
* :: DeveloperApi ::
* Task failed due to a runtime exception. This is the most common failure case and also captures
* user program exceptions.
+ *
+ * `stackTrace` contains the stack trace of the exception itself. It still exists for backward
+ * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to
+ * create `ExceptionFailure` as it will handle the backward compatibility properly.
+ *
+ * `fullStackTrace` is a better representation of the stack trace because it contains the whole
+ * stack trace including the exception and its causes
*/
@DeveloperApi
case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement],
+ fullStackTrace: String,
metrics: Option[TaskMetrics])
extends TaskFailedReason {
- override def toErrorString: String = Utils.exceptionString(className, description, stackTrace)
+
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ }
+
+ override def toErrorString: String =
+ if (fullStackTrace == null) {
+ // fullStackTrace is added in 1.2.0
+ // If fullStackTrace is null, use the old error string for backward compatibility
+ exceptionString(className, description, stackTrace)
+ } else {
+ fullStackTrace
+ }
+
+ /**
+ * Return a nice string representation of the exception, including the stack trace.
+ * Note: It does not include the exception's causes, and is only used for backward compatibility.
+ */
+ private def exceptionString(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement]): String = {
+ val desc = if (description == null) "" else description
+ val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n")
+ s"$className: $desc\n$st"
+ }
}
/**
@@ -117,8 +152,8 @@ case object TaskKilled extends TaskFailedReason {
* the task crashed the JVM.
*/
@DeveloperApi
-case object ExecutorLostFailure extends TaskFailedReason {
- override def toErrorString: String = "ExecutorLostFailure (executor lost)"
+case class ExecutorLostFailure(execId: String) extends TaskFailedReason {
+ override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)"
}
/**
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index efb8978f7ce12..5a8e5bb1f721a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -493,9 +493,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD as defined by
+ * Returns the top k (largest) elements from this RDD as defined by
* the specified Comparator[T].
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -507,9 +507,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD using the
+ * Returns the top k (largest) elements from this RDD using the
* natural ordering for T.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def top(num: Int): JList[T] = {
@@ -518,9 +518,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD as defined by
+ * Returns the first k (smallest) elements from this RDD as defined by
* the specified Comparator[T] and maintains the order.
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -552,9 +552,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD using the
+ * Returns the first k (smallest) elements from this RDD using the
+ * natural ordering for T while maintaining the order.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def takeOrdered(num: Int): JList[T] = {
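
To make the "k largest" versus "k smallest" wording above concrete, the underlying Scala RDD API behaves as follows (assuming an existing SparkContext sc):

    // top takes the k largest elements, returned in descending order;
    // takeOrdered takes the k smallest, returned in ascending order.
    val nums = sc.parallelize(Seq(5, 1, 4, 2, 3))
    nums.top(2)         // Array(5, 4)
    nums.takeOrdered(2) // Array(1, 2)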
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 0565adf4d4ead..5c6e8d32c5c8a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -28,11 +28,13 @@ import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.input.PortableDataStream
import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark._
-import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
@@ -202,6 +204,8 @@ class JavaSparkContext(val sc: SparkContext)
def textFile(path: String, minPartitions: Int): JavaRDD[String] =
sc.textFile(path, minPartitions)
+
+
/**
* Read a directory of text files from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI. Each file is read as a single record and returned in a
@@ -245,6 +249,84 @@ class JavaSparkContext(val sc: SparkContext)
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ *
+ * @param minPartitions A suggested minimum number of partitions for the input data.
+ */
+ def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, minPartitions))
+
+ /**
+ * :: Experimental ::
+ *
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file
+ * and the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
+
+ /**
+ * :: Experimental ::
+ *
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory of the input data files
+ * @param recordLength The length at which to split the records
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
+ new JavaRDD(sc.binaryRecords(path, recordLength))
+ }
+
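
A minimal usage sketch for the two new APIs, assuming a Scala SparkContext `sc` and the placeholder path from the docs above:

```scala
import org.apache.spark.SparkContext._
import org.apache.spark.input.PortableDataStream

// one (path, stream) pair per file; the stream is opened lazily on the executor
val streams = sc.binaryFiles("hdfs://a-hdfs-path")
val headers = streams.mapValues(_.toArray().take(16)) // first 16 bytes of each file

// fixed-width records: one byte array per 512-byte record
val records = sc.binaryRecords("hdfs://a-hdfs-path", 512)
```
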
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 49dc95f349eac..5ba66178e2b78 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
- conf: Broadcast[SerializableWritable[Configuration]],
- batchSize: Int) extends Converter[Any, Any] {
+ conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {
/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
- case w: Writable =>
- if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
+ case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 61b125ef7c6c1..45beb8fc8c925 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -21,13 +21,13 @@ import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+import org.apache.spark.input.PortableDataStream
+
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
import com.google.common.base.Charsets.UTF_8
-import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -397,22 +397,33 @@ private[spark] object PythonRDD extends Logging {
newIter.asInstanceOf[Iterator[String]].foreach { str =>
writeUTF(str, dataOut)
}
- case pair: Tuple2[_, _] =>
- pair._1 match {
- case bytePair: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
- dataOut.writeInt(pair._1.length)
- dataOut.write(pair._1)
- dataOut.writeInt(pair._2.length)
- dataOut.write(pair._2)
- }
- case stringPair: String =>
- newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- writeUTF(pair._1, dataOut)
- writeUTF(pair._2, dataOut)
- }
- case other =>
- throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+ case stream: PortableDataStream =>
+ newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, stream: PortableDataStream) =>
+ newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
+ case (key, stream) =>
+ writeUTF(key, dataOut)
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, value: String) =>
+ newIter.asInstanceOf[Iterator[(String, String)]].foreach {
+ case (key, value) =>
+ writeUTF(key, dataOut)
+ writeUTF(value, dataOut)
+ }
+ case (key: Array[Byte], value: Array[Byte]) =>
+ newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
+ case (key, value) =>
+ dataOut.writeInt(key.length)
+ dataOut.write(key)
+ dataOut.writeInt(value.length)
+ dataOut.write(value)
}
case other =>
throw new SparkException("Unexpected element type " + first.getClass)
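
Every branch above uses the same framing convention: a 4-byte big-endian length header followed by the raw payload (strings go through `writeUTF` instead). A standalone sketch of that convention, with a hypothetical helper name:

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// hypothetical helper mirroring the length-prefixed framing used above
def writeFramed(dataOut: DataOutputStream, bytes: Array[Byte]): Unit = {
  dataOut.writeInt(bytes.length) // 4-byte big-endian length header
  dataOut.write(bytes)           // raw payload
}

val buf = new ByteArrayOutputStream()
writeFramed(new DataOutputStream(buf), "hello".getBytes("UTF-8"))
```
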
@@ -442,7 +453,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -468,7 +479,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -494,7 +505,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -537,7 +548,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -563,7 +574,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -746,104 +757,6 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}
-
-
- /**
- * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- */
- @deprecated("PySpark does not use it anymore", "1.1")
- def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- SerDeUtil.initialize()
- iter.flatMap { row =>
- unpickle.loads(row) match {
- // in case of objects are pickled in batch mode
- case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // not in batch mode
- case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
- }
- }
- }
- }
-
- /**
- * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
- * It is only used by pyspark.sql.
- */
- def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
-
- def toArray(obj: Any): Array[_] = {
- obj match {
- case objs: JArrayList[_] =>
- objs.toArray
- case obj if obj.getClass.isArray =>
- obj.asInstanceOf[Array[_]].toArray
- }
- }
-
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].map(toArray)
- } else {
- Seq(toArray(obj))
- }
- }
- }.toJavaRDD()
- }
-
- private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
- private val pickle = new Pickler()
- private var batch = 1
- private val buffer = new mutable.ArrayBuffer[Any]
-
- override def hasNext(): Boolean = iter.hasNext
-
- override def next(): Array[Byte] = {
- while (iter.hasNext && buffer.length < batch) {
- buffer += iter.next()
- }
- val bytes = pickle.dumps(buffer.toArray)
- val size = bytes.length
- // let 1M < size < 10M
- if (size < 1024 * 1024) {
- batch *= 2
- } else if (size > 1024 * 1024 * 10 && batch > 1) {
- batch /= 2
- }
- buffer.clear()
- bytes
- }
- }
-
- /**
- * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
- * PySpark.
- */
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
- }
-
- /**
- * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
- */
- def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
- pyRDD.rdd.mapPartitions { iter =>
- SerDeUtil.initialize()
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].asScala
- } else {
- Seq(obj)
- }
- }
- }.toJavaRDD()
- }
}
private
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index ebdc3533e0992..a4153aaa926f8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python
import java.nio.ByteOrder
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.spark.api.java.JavaRDD
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Failure
import scala.util.Try
@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()
+
+ /**
+ * Convert an RDD of Java objects to Array (no recursive conversions).
+ * It is only used by pyspark.sql.
+ */
+ def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
+ jrdd.rdd.map {
+ case objs: JArrayList[_] =>
+ objs.toArray
+ case obj if obj.getClass.isArray =>
+ obj.asInstanceOf[Array[_]].toArray
+ }.toJavaRDD()
+ }
+
+ /**
+ * Chooses the batch size based on the size of the pickled objects,
+ * keeping each pickled batch between 1MB and 10MB.
+ */
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // keep each pickled batch between 1MB and 10MB
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+ * PySpark.
+ */
+ private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize()
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
+
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging {
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())
+
rdd.mapPartitions { iter =>
- val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
- if (batchSize > 1) {
- cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
+ if (batchSize == 0) {
+ new AutoBatchedPickler(cleaned)
} else {
- cleaned.map(pickle.dumps(_))
+ val pickle = new Pickler
+ cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
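
After this change `batchSize == 0` selects the adaptive `AutoBatchedPickler`, while any positive value pickles fixed-size groups; there is no longer an unbatched path. A sketch of both modes, assuming a pair RDD `rdd` of type `RDD[(Any, Any)]`:

```scala
// fixed batches: every pickle holds exactly 10 key-value pairs (except the last)
val fixed = SerDeUtil.pairRDDToPython(rdd, batchSize = 10)

// adaptive batches: the batch doubles while pickles stay under 1MB
// and halves when they exceed 10MB
val auto = SerDeUtil.pairRDDToPython(rdd, batchSize = 0)
```
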
@@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging {
/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
- def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
+ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
- Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
+ Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
- pyRDD.mapPartitions { iter =>
- initialize()
- val unpickle = new Unpickler
- val unpickled =
- if (batchSerialized) {
- iter.flatMap { batch =>
- unpickle.loads(batch) match {
- case objs: java.util.List[_] => collectionAsScalaIterable(objs)
- case other => throw new SparkException(
- s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
- }
- }
- } else {
- iter.map(unpickle.loads(_))
- }
- unpickled.map {
- case obj if isPair(obj) =>
- // we only accept (K, V)
- val arr = obj.asInstanceOf[Array[_]]
- (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
- case other => throw new SparkException(
- s"RDD element of type ${other.getClass.getName} cannot be used")
- }
+
+ val rdd = pythonToJava(pyRDD, batched).rdd
+ rdd.first match {
+ case obj if isPair(obj) =>
+ // we only accept (K, V)
+ case other => throw new SparkException(
+ s"RDD element of type ${other.getClass.getName} cannot be used")
+ }
+ rdd.map { obj =>
+ val arr = obj.asInstanceOf[Array[_]]
+ (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}
-
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index e9ca9166eb4d6..c0cbd28a845be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator {
// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
- ("1", TestWritable("test1", 123, 54.0)),
- ("2", TestWritable("test2", 456, 8762.3)),
- ("1", TestWritable("test3", 123, 423.1)),
- ("3", TestWritable("test56", 456, 423.5)),
- ("2", TestWritable("test2", 123, 5435.2))
+ ("1", TestWritable("test1", 1, 1.0)),
+ ("2", TestWritable("test2", 2, 2.3)),
+ ("3", TestWritable("test3", 3, 3.1)),
+ ("5", TestWritable("test56", 5, 5.5)),
+ ("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index e28eaad8a5180..60ee115e393ce 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy
+import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
@@ -133,14 +134,9 @@ class SparkHadoopUtil extends Logging {
*/
private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
: Option[() => Long] = {
- val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
- val scheme = qualifiedPath.toUri().getScheme()
- val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
try {
- val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
- val statisticsDataClass =
- Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
- val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead")
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesRead = f()
Some(() => f() - baselineBytesRead)
@@ -151,6 +147,42 @@ class SparkHadoopUtil extends Logging {
}
}
}
+
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes written. If
+ * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes written on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration)
+ : Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
+ val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesWritten = f()
+ Some(() => f() - baselineBytesWritten)
+ } catch {
+ case e: NoSuchMethodException => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
+ None
+ }
+ }
+ }
+
+ private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = {
+ val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
+ val scheme = qualifiedPath.toUri().getScheme()
+ val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
+ }
+
+ private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
+ val statisticsDataClass =
+ Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ statisticsDataClass.getDeclaredMethod(methodName)
+ }
}
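
A usage sketch for the two callbacks; both methods are `private[spark]`, so this assumes code living inside Spark, and both return `None` on Hadoop versions before 2.5:

```scala
// `path` and `conf` are an existing Hadoop Path and Configuration
val bytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(path, conf)
val bytesWritten = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(path, conf)

// ... perform filesystem I/O on this thread ...

// each callback reports the delta since it was created, on this thread only
bytesRead.foreach(f => println(s"bytes read: ${f()}"))
bytesWritten.foreach(f => println(s"bytes written: ${f()}"))
```
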
object SparkHadoopUtil {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index f97bf67fa5a3b..b43e68e40f791 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -158,8 +158,9 @@ object SparkSubmit {
args.files = mergeFileLists(args.files, args.primaryResource)
}
args.files = mergeFileLists(args.files, args.pyFiles)
- // Format python file paths properly before adding them to the PYTHONPATH
- sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",")
+ if (args.pyFiles != null) {
+ sysProps("spark.submit.pyFiles") = args.pyFiles
+ }
}
// Special flag to avoid deprecation warnings at the client
@@ -273,15 +274,32 @@ object SparkSubmit {
}
}
- // Properties given with --conf are superceded by other options, but take precedence over
- // properties in the defaults file.
+ // Load any properties specified through --conf and the default properties file
for ((k, v) <- args.sparkProperties) {
sysProps.getOrElseUpdate(k, v)
}
- // Read from default spark properties, if any
- for ((k, v) <- args.defaultSparkProperties) {
- sysProps.getOrElseUpdate(k, v)
+ // Resolve paths in certain spark properties
+ val pathConfigs = Seq(
+ "spark.jars",
+ "spark.files",
+ "spark.yarn.jar",
+ "spark.yarn.dist.files",
+ "spark.yarn.dist.archives")
+ pathConfigs.foreach { config =>
+ // Replace old URIs with resolved URIs, if they exist
+ sysProps.get(config).foreach { oldValue =>
+ sysProps(config) = Utils.resolveURIs(oldValue)
+ }
+ }
+
+ // Resolve and format python file paths properly before adding them to the PYTHONPATH.
+ // The resolving part is redundant in the case of --py-files, but necessary if the user
+    // explicitly sets `spark.submit.pyFiles` in the default properties file.
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles =>
+ val resolvedPyFiles = Utils.resolveURIs(pyFiles)
+ val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",")
+ sysProps("spark.submit.pyFiles") = formattedPyFiles
}
(childArgs, childClasspath, sysProps, childMainClass)
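
The effect of the resolution step is easiest to see on a relative local path. A sketch, assuming code inside Spark (`Utils` is `private[spark]`); the output shown is illustrative:

```scala
import org.apache.spark.util.Utils

// run from /home/alice: schemeless relative paths become absolute file: URIs,
// while paths that already carry a scheme pass through unchanged
Utils.resolveURIs("data.jar,hdfs://nn:8020/lib/dep.jar")
// -> "file:/home/alice/data.jar,hdfs://nn:8020/lib/dep.jar"
```
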
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 72a452e0aefb5..f0e9ee67f6a67 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy
import java.util.jar.JarFile
-import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark.util.Utils
@@ -72,39 +71,54 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
defaultProperties
}
- // Respect SPARK_*_MEMORY for cluster mode
- driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull
- executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull
-
+ // Set parameters from command line arguments
parseOpts(args.toList)
- mergeSparkProperties()
+ // Populate `sparkProperties` map from properties file
+ mergeDefaultSparkProperties()
+ // Use `sparkProperties` map along with env vars to fill in any missing parameters
+ loadEnvironmentArguments()
+
checkRequiredArguments()
/**
- * Fill in any undefined values based on the default properties file or options passed in through
- * the '--conf' flag.
+ * Merge values from the default properties file with those specified through --conf.
+ * When this is called, `sparkProperties` is already filled with configs from the latter.
*/
- private def mergeSparkProperties(): Unit = {
+ private def mergeDefaultSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env))
+ // Honor --conf before the defaults file
+ defaultSparkProperties.foreach { case (k, v) =>
+ if (!sparkProperties.contains(k)) {
+ sparkProperties(k) = v
+ }
+ }
+ }
- val properties = HashMap[String, String]()
- properties.putAll(defaultSparkProperties)
- properties.putAll(sparkProperties)
-
- // Use properties file as fallback for values which have a direct analog to
- // arguments in this script.
- master = Option(master).orElse(properties.get("spark.master")).orNull
- executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull
- executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull
+ /**
+ * Load arguments from environment variables, Spark properties, etc.
+ */
+ private def loadEnvironmentArguments(): Unit = {
+ master = Option(master)
+ .orElse(sparkProperties.get("spark.master"))
+ .orElse(env.get("MASTER"))
+ .orNull
+ driverMemory = Option(driverMemory)
+ .orElse(sparkProperties.get("spark.driver.memory"))
+ .orElse(env.get("SPARK_DRIVER_MEMORY"))
+ .orNull
+ executorMemory = Option(executorMemory)
+ .orElse(sparkProperties.get("spark.executor.memory"))
+ .orElse(env.get("SPARK_EXECUTOR_MEMORY"))
+ .orNull
+ executorCores = Option(executorCores)
+ .orElse(sparkProperties.get("spark.executor.cores"))
+ .orNull
totalExecutorCores = Option(totalExecutorCores)
- .orElse(properties.get("spark.cores.max"))
+ .orElse(sparkProperties.get("spark.cores.max"))
.orNull
- name = Option(name).orElse(properties.get("spark.app.name")).orNull
- jars = Option(jars).orElse(properties.get("spark.jars")).orNull
-
- // This supports env vars in older versions of Spark
- master = Option(master).orElse(env.get("MASTER")).orNull
+ name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
+ jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
// Try to set main class from JAR if no --class argument is given
@@ -131,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
/** Ensure that required fields exists. Call this only once all defaults are loaded. */
- private def checkRequiredArguments() = {
+ private def checkRequiredArguments(): Unit = {
if (args.length == 0) {
printUsageAndExit(-1)
}
@@ -166,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- override def toString = {
+ override def toString = {
s"""Parsed arguments:
| master $master
| deployMode $deployMode
@@ -174,7 +188,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| executorCores $executorCores
| totalExecutorCores $totalExecutorCores
| propertiesFile $propertiesFile
- | extraSparkProperties $sparkProperties
| driverMemory $driverMemory
| driverCores $driverCores
| driverExtraClassPath $driverExtraClassPath
@@ -193,8 +206,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| jars $jars
| verbose $verbose
|
- |Default properties from $propertiesFile:
- |${defaultSparkProperties.mkString(" ", "\n ", "\n")}
+ |Spark properties used, including those specified through
+ | --conf and those from the properties file $propertiesFile:
+ |${sparkProperties.mkString(" ", "\n ", "\n")}
""".stripMargin
}
@@ -327,7 +341,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
val outStream = SparkSubmit.printStream
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index d25c29113d6da..0e249e51a77d8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -84,11 +84,11 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 6ba395be1cc2c..ad7d81747c377 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.ApplicationDescription
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
index 2ac21186881fa..9d3d7938c6ccb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -19,6 +19,7 @@ package org.apache.spark.deploy.master
import java.util.Date
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.DriverDescription
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index 08a99bbe68578..6ff2aa5244847 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -18,10 +18,12 @@
package org.apache.spark.deploy.master
import java.io._
-
-import akka.serialization.Serialization
+import java.nio.ByteBuffer
import org.apache.spark.Logging
+import org.apache.spark.serializer.Serializer
+
+import scala.reflect.ClassTag
/**
* Stores data in a single on-disk directory with one file per application and worker.
@@ -32,65 +34,39 @@ import org.apache.spark.Logging
*/
private[spark] class FileSystemPersistenceEngine(
val dir: String,
- val serialization: Serialization)
+ val serialization: Serializer)
extends PersistenceEngine with Logging {
+ val serializer = serialization.newInstance()
new File(dir).mkdir()
- override def addApplication(app: ApplicationInfo) {
- val appFile = new File(dir + File.separator + "app_" + app.id)
- serializeIntoFile(appFile, app)
- }
-
- override def removeApplication(app: ApplicationInfo) {
- new File(dir + File.separator + "app_" + app.id).delete()
- }
-
- override def addDriver(driver: DriverInfo) {
- val driverFile = new File(dir + File.separator + "driver_" + driver.id)
- serializeIntoFile(driverFile, driver)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(new File(dir + File.separator + name), obj)
}
- override def removeDriver(driver: DriverInfo) {
- new File(dir + File.separator + "driver_" + driver.id).delete()
+ override def unpersist(name: String): Unit = {
+ new File(dir + File.separator + name).delete()
}
- override def addWorker(worker: WorkerInfo) {
- val workerFile = new File(dir + File.separator + "worker_" + worker.id)
- serializeIntoFile(workerFile, worker)
- }
-
- override def removeWorker(worker: WorkerInfo) {
- new File(dir + File.separator + "worker_" + worker.id).delete()
- }
-
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
- val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
- val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
- val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
- (apps, drivers, workers)
+ override def read[T: ClassTag](prefix: String) = {
+ val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix))
+ files.map(deserializeFromFile[T])
}
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
-
- val out = new FileOutputStream(file)
+ val out = serializer.serializeStream(new FileOutputStream(file))
try {
- out.write(serialized)
+ out.writeObject(value)
} finally {
out.close()
}
}
- def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](file: File): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
try {
@@ -99,8 +75,6 @@ private[spark] class FileSystemPersistenceEngine(
dis.close()
}
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
- serializer.fromBinary(fileData).asInstanceOf[T]
+ serializer.deserializeStream(dis).readObject()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index 4433a2ec29be6..cf77c86d760cf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -17,30 +17,27 @@
package org.apache.spark.deploy.master
-import akka.actor.{Actor, ActorRef}
-
-import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
+import org.apache.spark.annotation.DeveloperApi
/**
- * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
- * is the only Master serving requests.
- * In addition to the API provided, the LeaderElectionAgent will use of the following messages
- * to inform the Master of leader changes:
- * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
- * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
+ * :: DeveloperApi ::
+ *
+ * A LeaderElectionAgent tracks the current master and is the common interface for all election agents.
*/
-private[spark] trait LeaderElectionAgent extends Actor {
- // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring.
- val masterActor: ActorRef
+@DeveloperApi
+trait LeaderElectionAgent {
+ val masterActor: LeaderElectable
+ def stop() {} // default no-op, so implementations without cleanup need not override it
}
-/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
-private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
- override def preStart() {
- masterActor ! ElectedLeader
- }
+@DeveloperApi
+trait LeaderElectable {
+ def electedLeader()
+ def revokedLeadership()
+}
- override def receive = {
- case _ =>
- }
+/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
+private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable)
+ extends LeaderElectionAgent {
+ masterActor.electedLeader()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 2f81d472d7b78..021454e25804c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -50,7 +50,7 @@ private[spark] class Master(
port: Int,
webUiPort: Int,
val securityMgr: SecurityManager)
- extends Actor with ActorLogReceive with Logging {
+ extends Actor with ActorLogReceive with Logging with LeaderElectable {
import context.dispatcher // to use Akka's scheduler.schedule()
@@ -61,7 +61,6 @@ private[spark] class Master(
val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
- val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
val workers = new HashSet[WorkerInfo]
@@ -103,7 +102,7 @@ private[spark] class Master(
var persistenceEngine: PersistenceEngine = _
- var leaderElectionAgent: ActorRef = _
+ var leaderElectionAgent: LeaderElectionAgent = _
private var recoveryCompletionTask: Cancellable = _
@@ -130,23 +129,24 @@ private[spark] class Master(
masterMetricsSystem.start()
applicationMetricsSystem.start()
- persistenceEngine = RECOVERY_MODE match {
+ val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
- new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf)
+ val zkFactory = new ZooKeeperRecoveryModeFactory(conf)
+ (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
- logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
- new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
+ val fsFactory = new FileSystemRecoveryModeFactory(conf)
+ (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
+ case "CUSTOM" =>
+ val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
+ val factory = clazz.getConstructor(conf.getClass)
+ .newInstance(conf).asInstanceOf[StandaloneRecoveryModeFactory]
+ (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
- new BlackHolePersistenceEngine()
+ (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this))
}
-
- leaderElectionAgent = RECOVERY_MODE match {
- case "ZOOKEEPER" =>
- context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf))
- case _ =>
- context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
- }
+ persistenceEngine = persistenceEngine_
+ leaderElectionAgent = leaderElectionAgent_
}
override def preRestart(reason: Throwable, message: Option[Any]) {
@@ -165,7 +165,15 @@ private[spark] class Master(
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
- context.stop(leaderElectionAgent)
+ leaderElectionAgent.stop()
+ }
+
+ override def electedLeader() {
+ self ! ElectedLeader
+ }
+
+ override def revokedLeadership() {
+ self ! RevokedLeadership
}
override def receiveWithLogging = {
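
With the new CUSTOM branch, a user-supplied factory is wired in purely through configuration. A sketch with a hypothetical factory class name:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.deploy.recoveryMode", "CUSTOM")
  // hypothetical class; must extend StandaloneRecoveryModeFactory and expose
  // a single-argument SparkConf constructor, matching the reflection above
  .set("spark.deploy.recoveryMode.factory", "com.example.MyRecoveryModeFactory")
```
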
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index e3640ea4f7e64..2e0e1e7036ac8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -17,6 +17,10 @@
package org.apache.spark.deploy.master
+import org.apache.spark.annotation.DeveloperApi
+
+import scala.reflect.ClassTag
+
/**
* Allows Master to persist any state that is necessary in order to recover from a failure.
* The following semantics are required:
@@ -25,36 +29,70 @@ package org.apache.spark.deploy.master
* Given these two requirements, we will have all apps and workers persisted, but
* we might not have yet deleted apps or workers that finished (so their liveness must be verified
* during recovery).
+ *
+ * The implementation of this trait defines how name-object pairs are stored or retrieved.
*/
-private[spark] trait PersistenceEngine {
- def addApplication(app: ApplicationInfo)
+@DeveloperApi
+trait PersistenceEngine {
- def removeApplication(app: ApplicationInfo)
+ /**
+ * Defines how the object is serialized and persisted. Implementation will
+ * depend on the store used.
+ */
+ def persist(name: String, obj: Object)
- def addWorker(worker: WorkerInfo)
+ /**
+ * Defines how the object referred by its name is removed from the store.
+ */
+ def unpersist(name: String)
- def removeWorker(worker: WorkerInfo)
+ /**
+ * Returns all objects matching a prefix. This defines how objects are
+ * read and deserialized back.
+ */
+ def read[T: ClassTag](prefix: String): Seq[T]
- def addDriver(driver: DriverInfo)
+ final def addApplication(app: ApplicationInfo): Unit = {
+ persist("app_" + app.id, app)
+ }
- def removeDriver(driver: DriverInfo)
+ final def removeApplication(app: ApplicationInfo): Unit = {
+ unpersist("app_" + app.id)
+ }
+
+ final def addWorker(worker: WorkerInfo): Unit = {
+ persist("worker_" + worker.id, worker)
+ }
+
+ final def removeWorker(worker: WorkerInfo): Unit = {
+ unpersist("worker_" + worker.id)
+ }
+
+ final def addDriver(driver: DriverInfo): Unit = {
+ persist("driver_" + driver.id, driver)
+ }
+
+ final def removeDriver(driver: DriverInfo): Unit = {
+ unpersist("driver_" + driver.id)
+ }
/**
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo])
+ final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+ (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ }
def close() {}
}
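
Since the add/remove helpers are now final and funnel through three primitives, a custom store only needs `persist`, `unpersist`, and `read`. A minimal sketch of a hypothetical in-memory engine (test-only, since it does not survive a Master restart; assumes access to the org.apache.spark.deploy.master package):

```scala
import scala.collection.mutable
import scala.reflect.ClassTag

class InMemoryPersistenceEngine extends PersistenceEngine {
  private val store = new mutable.HashMap[String, Object]

  override def persist(name: String, obj: Object): Unit =
    store.synchronized { store(name) = obj }

  override def unpersist(name: String): Unit =
    store.synchronized { store -= name }

  override def read[T: ClassTag](prefix: String): Seq[T] = store.synchronized {
    store.collect { case (k, v) if k.startsWith(prefix) => v.asInstanceOf[T] }.toSeq
  }
}
```
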
private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
- override def addApplication(app: ApplicationInfo) {}
- override def removeApplication(app: ApplicationInfo) {}
- override def addWorker(worker: WorkerInfo) {}
- override def removeWorker(worker: WorkerInfo) {}
- override def addDriver(driver: DriverInfo) {}
- override def removeDriver(driver: DriverInfo) {}
-
- override def readPersistedData() = (Nil, Nil, Nil)
+
+ override def persist(name: String, obj: Object): Unit = {}
+
+ override def unpersist(name: String): Unit = {}
+
+ override def read[T: ClassTag](name: String): Seq[T] = Nil
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
new file mode 100644
index 0000000000000..d9d36c1ed5f9f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.JavaSerializer
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Implementations of this class can be plugged in as an alternative recovery mode for Spark's
+ * standalone mode.
+ *
+ */
+@DeveloperApi
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf) {
+
+ /**
+ * The PersistenceEngine defines how persistent data (information about workers, drivers, etc.)
+ * is handled for recovery.
+ *
+ */
+ def createPersistenceEngine(): PersistenceEngine
+
+ /**
+ * Create an instance of LeaderElectionAgent that decides who gets elected as master.
+ */
+ def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent
+}
+
+/**
+ * The LeaderElectionAgent in this case is a no-op: the leader is forever the leader, since the
+ * actual recovery is done by restoring state from the filesystem.
+ */
+private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf)
+ extends StandaloneRecoveryModeFactory(conf) with Logging {
+ val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
+
+ def createPersistenceEngine() = {
+ logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
+ new FileSystemPersistenceEngine(RECOVERY_DIR, new JavaSerializer(conf))
+ }
+
+ def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master)
+}
+
+private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf)
+ extends StandaloneRecoveryModeFactory(conf) {
+ def createPersistenceEngine() = new ZooKeeperPersistenceEngine(new JavaSerializer(conf), conf)
+
+ def createLeaderElectionAgent(master: LeaderElectable) =
+ new ZooKeeperLeaderElectionAgent(master, conf)
+}
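
Tying it together, a hypothetical factory for the CUSTOM mode sketched earlier could pair the in-memory engine above with the single-node MonarchyLeaderAgent (both `private[spark]`, so again this assumes code inside the package):

```scala
import org.apache.spark.SparkConf

class MyRecoveryModeFactory(conf: SparkConf) extends StandaloneRecoveryModeFactory(conf) {

  override def createPersistenceEngine(): PersistenceEngine =
    new InMemoryPersistenceEngine()

  override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent =
    new MonarchyLeaderAgent(master)
}
```
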
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index d221b0f6cc86b..473ddc23ff0f3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 285f9b014e291..8eaa0ad948519 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -24,9 +24,8 @@ import org.apache.spark.deploy.master.MasterMessages._
import org.apache.curator.framework.CuratorFramework
import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
-private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
- masterUrl: String, conf: SparkConf)
- extends LeaderElectionAgent with LeaderLatchListener with Logging {
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable,
+ conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging {
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
@@ -34,30 +33,21 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
private var leaderLatch: LeaderLatch = _
private var status = LeadershipStatus.NOT_LEADER
- override def preStart() {
+ start()
+ def start() {
logInfo("Starting ZooKeeper LeaderElection agent")
zk = SparkCuratorUtil.newClient(conf)
leaderLatch = new LeaderLatch(zk, WORKING_DIR)
leaderLatch.addListener(this)
-
leaderLatch.start()
}
- override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
- logError("LeaderElectionAgent failed...", reason)
- super.preRestart(reason, message)
- }
-
- override def postStop() {
+ override def stop() {
leaderLatch.close()
zk.close()
}
- override def receive = {
- case _ =>
- }
-
override def isLeader() {
synchronized {
// could have lost leadership by now.
@@ -85,10 +75,10 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
def updateLeadershipStatus(isLeader: Boolean) {
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
status = LeadershipStatus.LEADER
- masterActor ! ElectedLeader
+ masterActor.electedLeader()
} else if (!isLeader && status == LeadershipStatus.LEADER) {
status = LeadershipStatus.NOT_LEADER
- masterActor ! RevokedLeadership
+ masterActor.revokedLeadership()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 834dfedee52ce..96c2139eb02f0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -19,72 +19,54 @@ package org.apache.spark.deploy.master
import scala.collection.JavaConversions._
-import akka.serialization.Serialization
import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.serializer.Serializer
+import java.nio.ByteBuffer
-class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
+import scala.reflect.ClassTag
+
+
+private[spark] class ZooKeeperPersistenceEngine(val serialization: Serializer, conf: SparkConf)
extends PersistenceEngine
with Logging
{
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
val zk: CuratorFramework = SparkCuratorUtil.newClient(conf)
- SparkCuratorUtil.mkdir(zk, WORKING_DIR)
-
- override def addApplication(app: ApplicationInfo) {
- serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
- }
+ val serializer = serialization.newInstance()
- override def removeApplication(app: ApplicationInfo) {
- zk.delete().forPath(WORKING_DIR + "/app_" + app.id)
- }
+ SparkCuratorUtil.mkdir(zk, WORKING_DIR)
- override def addDriver(driver: DriverInfo) {
- serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver)
- }
- override def removeDriver(driver: DriverInfo) {
- zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(WORKING_DIR + "/" + name, obj)
}
- override def addWorker(worker: WorkerInfo) {
- serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
+ override def unpersist(name: String): Unit = {
+ zk.delete().forPath(WORKING_DIR + "/" + name)
}
- override def removeWorker(worker: WorkerInfo) {
- zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id)
+ override def read[T: ClassTag](prefix: String) = {
+ val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix))
+ file.map(deserializeFromFile[T]).flatten
}
override def close() {
zk.close()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted
- val appFiles = sortedFiles.filter(_.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten
- val driverFiles = sortedFiles.filter(_.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten
- val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten
- (apps, drivers, workers)
- }
-
private def serializeIntoFile(path: String, value: AnyRef) {
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
+ val serialized = serializer.serialize(value)
+ zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized.array())
}
- def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = {
+ def deserializeFromFile[T](filename: String): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
try {
- Some(serializer.fromBinary(fileData).asInstanceOf[T])
+ Some(serializer.deserialize(ByteBuffer.wrap(fileData)))
} catch {
case e: Exception => {
logWarning("Exception while reading persisted file, deleting", e)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index aba2e20118d7a..28e9662db5da9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -37,12 +37,12 @@ object CommandUtils extends Logging {
* The `env` argument is exposed for testing.
*/
def buildProcessBuilder(
- command: Command,
- memory: Int,
- sparkHome: String,
- substituteArguments: String => String,
- classPaths: Seq[String] = Seq[String](),
- env: Map[String, String] = sys.env): ProcessBuilder = {
+ command: Command,
+ memory: Int,
+ sparkHome: String,
+ substituteArguments: String => String,
+ classPaths: Seq[String] = Seq[String](),
+ env: Map[String, String] = sys.env): ProcessBuilder = {
val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env)
val commandSeq = buildCommandSeq(localCommand, memory, sparkHome)
val builder = new ProcessBuilder(commandSeq: _*)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
new file mode 100644
index 0000000000000..d044e1d01d429
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+
+/**
+ * Provides a server from which Executors can read shuffle files (rather than reading directly from
+ * each other), to provide uninterrupted access to the files in the face of executors being turned
+ * off or killed.
+ *
+ * Optionally requires SASL authentication in order to read. See [[SecurityManager]].
+ */
+private[worker]
+class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
+ extends Logging {
+
+ private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
+ private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
+ private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
+
+ private val transportConf = SparkTransportConf.fromSparkConf(sparkConf)
+ private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
+ private val transportContext: TransportContext = {
+ val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
+ new TransportContext(transportConf, handler)
+ }
+
+ private var server: TransportServer = _
+
+ /** Starts the external shuffle service if the user has configured us to. */
+ def startIfEnabled() {
+ if (enabled) {
+ require(server == null, "Shuffle server already started")
+ logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+ server = transportContext.createServer(port)
+ }
+ }
+
+ def stop() {
+ if (enabled && server != null) {
+ server.close()
+ server = null
+ }
+ }
+}
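
The service is opt-in; the two knobs it reads are the ones in the constructor above. A minimal worker-side configuration sketch:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true") // default: false
  .set("spark.shuffle.service.port", "7337")    // the default port, shown explicitly
```
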
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index c4a8ec2e5e7b0..ca262de832e25 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -111,6 +111,9 @@ private[spark] class Worker(
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]
+ // The shuffle service is not actually started unless configured.
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
+
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else host
@@ -154,6 +157,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
@@ -186,11 +190,11 @@ private[spark] class Worker(
private def retryConnectToMaster() {
Utils.tryOrExit {
connectionAttemptCount += 1
- logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount")
if (registered) {
registrationRetryTimer.foreach(_.cancel())
registrationRetryTimer = None
} else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
+ logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
tryRegisterAllMasters()
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
registrationRetryTimer.foreach(_.cancel())
@@ -419,6 +423,7 @@ private[spark] class Worker(
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
+ shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
@@ -441,7 +446,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
- workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ workDir: String,
+ workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 697154d762d41..3711824a40cfc 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Create a new ActorSystem using driver's Spark properties to run the backend.
val driverConf = new SparkConf().setAll(props)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
+ SparkEnv.executorActorSystemName,
+ hostname, port, driverConf, new SecurityManager(driverConf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 2889f59e33e84..caf4d76713d49 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
-import akka.actor.ActorSystem
+import akka.actor.{Props, ActorSystem}
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
@@ -78,7 +78,7 @@ private[spark] class Executor(
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- conf.set("spark.executor.id", "executor." + executorId)
+ conf.set("spark.executor.id", executorId)
private val env = {
if (!isLocal) {
val port = conf.getInt("spark.executor.port", 0)
@@ -86,12 +86,17 @@ private[spark] class Executor(
conf, executorId, slaveHostname, port, isLocal, actorSystem)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
+ _env.blockManager.initialize(conf.getAppId)
_env
} else {
SparkEnv.get
}
}
+ // Create an actor for receiving RPCs from the driver
+ private val executorActor = env.actorSystem.actorOf(
+ Props(new ExecutorActor(executorId)), "ExecutorActor")
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -104,6 +109,9 @@ private[spark] class Executor(
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+ // Limit of bytes for total size of results (default is 1GB)
+ private val maxResultSize = Utils.getMaxResultSize(conf)
+
// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -128,6 +136,7 @@ private[spark] class Executor(
def stop() {
env.metricsSystem.report()
+ env.actorSystem.stop(executorActor)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -152,7 +161,7 @@ private[spark] class Executor(
}
override def run() {
- val startTime = System.currentTimeMillis()
+ val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
@@ -197,7 +206,7 @@ private[spark] class Executor(
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.executorDeserializeTime = taskStart - startTime
+ m.executorDeserializeTime = taskStart - deserializeStartTime
m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
m.resultSerializationTime = afterSerialization - beforeSerialization
@@ -210,25 +219,27 @@ private[spark] class Executor(
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
- val (serializedResult, directSend) = {
- if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+ val serializedResult = {
+ if (resultSize > maxResultSize) {
+ logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
+ s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
+ s"dropping it.")
+ ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
+ } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
- (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
+ logInfo(
+            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager")
+ ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
- (serializedDirectResult, true)
+ logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
+ serializedDirectResult
}
}
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
- if (directSend) {
- logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
- } else {
- logInfo(
- s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
- }
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
@@ -252,7 +263,7 @@ private[spark] class Executor(
m.executorRunTime = serviceTime
m.jvmGCTime = gcTime - startGCTime
}
- val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
+ val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Don't forcibly exit unless the exception was inherently fatal, to avoid
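
The hunk above leaves three ways for a task result to travel back to the driver. A minimal, self-contained sketch of that decision, with hypothetical sizes standing in for the serialized result:

```scala
// Sketch of the three-way result routing above; thresholds are hypothetical.
def routeResult(resultSize: Long, maxResultSize: Long,
                akkaFrameSize: Long, reservedBytes: Long): String = {
  if (resultSize > maxResultSize) {
    "dropped"       // driver only learns the size via IndirectTaskResult
  } else if (resultSize >= akkaFrameSize - reservedBytes) {
    "blockManager"  // stored under a TaskResultBlockId and fetched lazily
  } else {
    "direct"        // small enough to ride along in the Akka status update
  }
}

assert(routeResult(2L << 30, 1L << 30, 10L << 20, 200L << 10) == "dropped")
assert(routeResult(64L << 20, 1L << 30, 10L << 20, 200L << 10) == "blockManager")
assert(routeResult(1L << 10, 1L << 30, 10L << 20, 200L << 10) == "direct")
```
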
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
new file mode 100644
index 0000000000000..41925f7e97e84
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import akka.actor.Actor
+import org.apache.spark.Logging
+
+import org.apache.spark.util.{Utils, ActorLogReceive}
+
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * Actor that runs inside executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case TriggerThreadDump =>
+ sender ! Utils.getThreadDump()
+ }
+
+}
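
ExecutorActor is plain ask-pattern RPC. A self-contained miniature of the same pattern (the message name and the reply payload below are stand-ins, not Spark API):

```scala
import scala.concurrent.Await
import scala.concurrent.duration._

import akka.actor.{Actor, ActorSystem, Props}
import akka.pattern.ask
import akka.util.Timeout

// Stand-in for TriggerThreadDump and its thread-dump reply; purely illustrative.
case object Ping

class Echo extends Actor {
  def receive = { case Ping => sender ! "reply payload" }
}

object AskDemo extends App {
  val system = ActorSystem("demo")
  implicit val timeout = Timeout(5.seconds)
  val ref = system.actorOf(Props[Echo], "echo")
  println(Await.result(ref ? Ping, timeout.duration)) // prints "reply payload"
  system.shutdown()
}
```
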
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 57bc2b40cec44..51b5328cb4c8f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -82,6 +82,12 @@ class TaskMetrics extends Serializable {
*/
var inputMetrics: Option[InputMetrics] = None
+ /**
+ * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
+ * data was written are stored here.
+ */
+ var outputMetrics: Option[OutputMetrics] = None
+
/**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
* This includes read metrics aggregated over all the task's shuffle dependencies.
@@ -157,6 +163,16 @@ object DataReadMethod extends Enumeration with Serializable {
val Memory, Disk, Hadoop, Network = Value
}
+/**
+ * :: DeveloperApi ::
+ * Method by which output data was written.
+ */
+@DeveloperApi
+object DataWriteMethod extends Enumeration with Serializable {
+ type DataWriteMethod = Value
+ val Hadoop = Value
+}
+
/**
* :: DeveloperApi ::
* Metrics about reading input data.
@@ -169,6 +185,18 @@ case class InputMetrics(readMethod: DataReadMethod.Value) {
var bytesRead: Long = 0L
}
+/**
+ * :: DeveloperApi ::
+ * Metrics about writing output data.
+ */
+@DeveloperApi
+case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
+ /**
+ * Total bytes written
+ */
+ var bytesWritten: Long = 0L
+}
+
/**
* :: DeveloperApi ::
* Metrics pertaining to shuffle data read in a given task.
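
As a sketch, populating the new output metrics might look like the following; the byte count is hypothetical, and the real wiring happens in the Hadoop write path:

```scala
// Hypothetical use of the new metrics types added above.
val outputMetrics = OutputMetrics(DataWriteMethod.Hadoop)
outputMetrics.bytesWritten = 4096L // bytes actually written, per the write path
// A task would then expose it via: taskMetrics.outputMetrics = Some(outputMetrics)
```
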
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
new file mode 100644
index 0000000000000..89b29af2000c8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+
+/**
+ * Custom Input Format for reading and splitting flat binary files that contain records,
+ * each of which is a fixed size in bytes. The fixed record size is specified through
+ * the recordLength parameter in the Hadoop configuration.
+ */
+private[spark] object FixedLengthBinaryInputFormat {
+ /** Property name to set in Hadoop JobConfs for record length */
+ val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength"
+
+ /** Retrieves the record length property from a Hadoop configuration */
+ def getRecordLength(context: JobContext): Int = {
+ context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt
+ }
+}
+
+private[spark] class FixedLengthBinaryInputFormat
+ extends FileInputFormat[LongWritable, BytesWritable] {
+
+ private var recordLength = -1
+
+ /**
+ * Override of isSplitable to ensure initial computation of the record length
+ */
+ override def isSplitable(context: JobContext, filename: Path): Boolean = {
+ if (recordLength == -1) {
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ }
+ if (recordLength <= 0) {
+      println("record length is less than or equal to 0, file cannot be split")
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * This input format overrides computeSplitSize() to make sure that each split
+ * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader
+   * will start at the first byte of a record, and the last byte will be the last byte of a record.
+ */
+ override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = {
+ val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize)
+ // If the default size is less than the length of a record, make it equal to it
+    // Otherwise, make the split size as close as possible to the default size,
+    // while still containing a complete set of records, with the first record
+    // starting at the first byte in the split and the last record ending with the last byte
+ if (defaultSize < recordLength) {
+ recordLength.toLong
+ } else {
+      (defaultSize / recordLength) * recordLength // integer division already floors
+ }
+ }
+
+ /**
+ * Create a FixedLengthBinaryRecordReader
+ */
+ override def createRecordReader(split: InputSplit, context: TaskAttemptContext)
+ : RecordReader[LongWritable, BytesWritable] = {
+ new FixedLengthBinaryRecordReader
+ }
+}
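
A quick numeric check of the computeSplitSize() rounding above, under assumed sizes:

```scala
// Assumed: Hadoop proposes a 1,000,000-byte split; records are 300 bytes.
val defaultSize = 1000000L
val recordLength = 300
val splitSize =
  if (defaultSize < recordLength) recordLength.toLong
  else (defaultSize / recordLength) * recordLength // 3333 full records
assert(splitSize == 999900L) // never cuts a record in half
```
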
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
new file mode 100644
index 0000000000000..36a1e5d475f46
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.FSDataInputStream
+import org.apache.hadoop.io.compress.CompressionCodecFactory
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+
+/**
+ * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
+ * It uses the record length set in FixedLengthBinaryInputFormat to
+ * read one record at a time from the given InputSplit.
+ *
+ * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value.
+ *
+ * key = record index (Long)
+ * value = the record itself (BytesWritable)
+ */
+private[spark] class FixedLengthBinaryRecordReader
+ extends RecordReader[LongWritable, BytesWritable] {
+
+ private var splitStart: Long = 0L
+ private var splitEnd: Long = 0L
+ private var currentPosition: Long = 0L
+ private var recordLength: Int = 0
+ private var fileInputStream: FSDataInputStream = null
+ private var recordKey: LongWritable = null
+ private var recordValue: BytesWritable = null
+
+ override def close() {
+ if (fileInputStream != null) {
+ fileInputStream.close()
+ }
+ }
+
+ override def getCurrentKey: LongWritable = {
+ recordKey
+ }
+
+ override def getCurrentValue: BytesWritable = {
+ recordValue
+ }
+
+  override def getProgress: Float = {
+    if (splitStart == splitEnd) {
+      0.0f
+    } else {
+      // Use floating-point division; plain Long division would always yield 0
+      Math.min((currentPosition - splitStart).toFloat / (splitEnd - splitStart), 1.0f)
+    }
+  }
+
+ override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) {
+ // the file input
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+
+ // the byte position this fileSplit starts at
+ splitStart = fileSplit.getStart
+
+ // splitEnd byte marker that the fileSplit ends at
+ splitEnd = splitStart + fileSplit.getLength
+
+ // the actual file we will be reading from
+ val file = fileSplit.getPath
+ // job configuration
+ val job = context.getConfiguration
+ // check compression
+ val codec = new CompressionCodecFactory(job).getCodec(file)
+ if (codec != null) {
+ throw new IOException("FixedLengthRecordReader does not support reading compressed files")
+ }
+ // get the record length
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ // get the filesystem
+ val fs = file.getFileSystem(job)
+ // open the File
+ fileInputStream = fs.open(file)
+ // seek to the splitStart position
+ fileInputStream.seek(splitStart)
+ // set our current position
+ currentPosition = splitStart
+ }
+
+ override def nextKeyValue(): Boolean = {
+ if (recordKey == null) {
+ recordKey = new LongWritable()
+ }
+    // the key is a linear index of the record, given by the position
+    // the record starts at divided by the record length
+ recordKey.set(currentPosition / recordLength)
+ // the recordValue to place the bytes into
+ if (recordValue == null) {
+ recordValue = new BytesWritable(new Array[Byte](recordLength))
+ }
+ // read a record if the currentPosition is less than the split end
+ if (currentPosition < splitEnd) {
+ // setup a buffer to store the record
+ val buffer = recordValue.getBytes
+ fileInputStream.readFully(buffer)
+ // update our current position
+ currentPosition = currentPosition + recordLength
+ // return true
+ return true
+ }
+ false
+ }
+}
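
The key computed in nextKeyValue() is a global record index, so it stays consistent across splits. Continuing the numbers from the previous sketch:

```scala
// With 300-byte records, the split starting at byte 999900 begins
// exactly at record number 3333; no per-split offset bookkeeping is needed.
val recordLength = 300
val currentPosition = 999900L
assert(currentPosition / recordLength == 3333L)
```
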
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
new file mode 100644
index 0000000000000..457472547fcbb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.JavaConversions._
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * A general format for reading whole files in as streams, byte arrays,
+ * or other types to be added later
+ */
+private[spark] abstract class StreamFileInputFormat[T]
+ extends CombineFileInputFormat[String, T]
+{
+ override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+
+ /**
+   * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API,
+   * which is honored through setMaxSplitSize
+ */
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
+ val files = listStatus(context)
+ val totalLen = files.map { file =>
+ if (file.isDir) 0L else file.getLen
+ }.sum
+
+    // Divide by the requested partition count (not the file count) so that
+    // minPartitions actually takes effect
+    val maxSplitSize = Math.ceil(totalLen * 1.0 /
+      (if (minPartitions == 0) 1 else minPartitions)).toLong
+ super.setMaxSplitSize(maxSplitSize)
+ }
+
+ def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T]
+
+}
+
+/**
+ * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]]
+ * for reading files out as streams
+ */
+private[spark] abstract class StreamBasedRecordReader[T](
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends RecordReader[String, T] {
+
+  // True means the current file has been processed and should be skipped.
+ private var processed = false
+
+ private var key = ""
+ private var value: T = null.asInstanceOf[T]
+
+ override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+ override def close() = {}
+
+ override def getProgress = if (processed) 1.0f else 0.0f
+
+ override def getCurrentKey = key
+
+ override def getCurrentValue = value
+
+ override def nextKeyValue = {
+ if (!processed) {
+ val fileIn = new PortableDataStream(split, context, index)
+ value = parseStream(fileIn)
+      fileIn.close() // if it has not been opened yet, close() does nothing
+ key = fileIn.getPath
+ processed = true
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+   * Parse the stream (and close it afterwards) and return the value as type T
+   * @param inStream the stream to be read in
+   * @return the data parsed into type T
+ */
+ def parseStream(inStream: PortableDataStream): T
+}
+
+/**
+ * Reads the record in directly as a stream for other objects to manipulate and handle
+ */
+private[spark] class StreamRecordReader(
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends StreamBasedRecordReader[PortableDataStream](split, context, index) {
+
+ def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
+}
+
+/**
+ * The format for the PortableDataStream files
+ */
+private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] {
+ override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = {
+ new CombineFileRecordReader[String, PortableDataStream](
+ split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader])
+ }
+}
+
+/**
+ * A class that allows DataStreams to be serialized and moved around by not creating them
+ * until they need to be read
+ * @note TaskAttemptContext is not serializable resulting in the confBytes construct
+ * @note CombineFileSplit is not serializable resulting in the splitBytes construct
+ */
+@Experimental
+class PortableDataStream(
+ @transient isplit: CombineFileSplit,
+ @transient context: TaskAttemptContext,
+ index: Integer)
+ extends Serializable {
+
+  // transient forces the file to be reopened after serialization;
+  // it is also used for the non-serializable members below
+
+ @transient private var fileIn: DataInputStream = null
+ @transient private var isOpen = false
+
+ private val confBytes = {
+ val baos = new ByteArrayOutputStream()
+ context.getConfiguration.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ private val splitBytes = {
+ val baos = new ByteArrayOutputStream()
+ isplit.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ @transient private lazy val split = {
+ val bais = new ByteArrayInputStream(splitBytes)
+ val nsplit = new CombineFileSplit()
+ nsplit.readFields(new DataInputStream(bais))
+ nsplit
+ }
+
+ @transient private lazy val conf = {
+ val bais = new ByteArrayInputStream(confBytes)
+ val nconf = new Configuration()
+ nconf.readFields(new DataInputStream(bais))
+ nconf
+ }
+
+  /**
+ * Calculate the path name independently of opening the file
+ */
+ @transient private lazy val path = {
+ val pathp = split.getPath(index)
+ pathp.toString
+ }
+
+ /**
+ * Create a new DataInputStream from the split and context
+ */
+ def open(): DataInputStream = {
+ if (!isOpen) {
+ val pathp = split.getPath(index)
+ val fs = pathp.getFileSystem(conf)
+ fileIn = fs.open(pathp)
+ isOpen = true
+ }
+ fileIn
+ }
+
+ /**
+ * Read the file as a byte array
+ */
+ def toArray(): Array[Byte] = {
+ open()
+ val innerBuffer = ByteStreams.toByteArray(fileIn)
+ close()
+ innerBuffer
+ }
+
+ /**
+ * Close the file (if it is currently open)
+ */
+ def close() = {
+ if (isOpen) {
+ try {
+ fileIn.close()
+ isOpen = false
+ } catch {
+ case ioe: java.io.IOException => // do nothing
+ }
+ }
+ }
+
+ def getPath(): String = path
+}
+
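
Hypothetical usage of PortableDataStream from an RDD, assuming the binaryFiles-style entry point this PR introduces elsewhere produces an RDD[(String, PortableDataStream)]:

```scala
import org.apache.spark.SparkContext._ // pair-RDD implicits (pre-1.3 style)
import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD

// Illustrative only: map each (path, stream) pair to the file's size in bytes.
def fileSizes(rdd: RDD[(String, PortableDataStream)]): RDD[(String, Int)] =
  rdd.mapValues { stream =>
    stream.toArray().length // toArray() opens, reads fully, and closes the file
  }
```
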
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 4cb450577796a..183bce3d8d8d3 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -48,9 +48,10 @@ private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[Str
}
/**
- * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API.
+ * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API,
+ * which is honored through setMaxSplitSize
*/
- def setMaxSplitSize(context: JobContext, minPartitions: Int) {
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
val files = listStatus(context)
val totalLen = files.map { file =>
if (file.isDir) 0L else file.getLen
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
similarity index 79%
rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 0c47afae54c8b..21b782edd2a9e 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -15,15 +15,24 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapred
+package org.apache.spark.mapred
-private[apache]
+import java.lang.reflect.Modifier
+
+import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext}
+
+private[spark]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
+ // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private.
+ // Make it accessible if it's not in order to access it.
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
@@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
+ // See above
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
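
The setAccessible trick above, shown in isolation on an ordinary non-public constructor (the class here is a stand-in, not Hadoop's):

```scala
import java.lang.reflect.Modifier

// Stand-in for Hadoop 1.0.x's package-private JobContextImpl constructor.
class Hidden protected (val x: Int)

object ReflectDemo extends App {
  val ctor = classOf[Hidden].getDeclaredConstructor(classOf[Int])
  if (!Modifier.isPublic(ctor.getModifiers)) {
    ctor.setAccessible(true) // same pattern as newJobContext above
  }
  println(ctor.newInstance(Int.box(42)).x) // 42
}
```
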
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
similarity index 96%
rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
index 1fca5729c6092..3340673f91156 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapreduce
+package org.apache.spark.mapreduce
import java.lang.{Boolean => JBoolean, Integer => JInteger}
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
-private[apache]
+private[spark]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index b083f465334fe..dcbda5a8515dd 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -20,16 +20,16 @@ package org.apache.spark.network
import java.io.Closeable
import java.nio.ByteBuffer
-import scala.concurrent.{Await, Future}
+import scala.concurrent.{Promise, Await, Future}
import scala.concurrent.duration.Duration
import org.apache.spark.Logging
import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
-import org.apache.spark.storage.{BlockId, StorageLevel}
-import org.apache.spark.util.Utils
+import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
+import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}
private[spark]
-abstract class BlockTransferService extends Closeable with Logging {
+abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -60,10 +60,11 @@ abstract class BlockTransferService extends Closeable with Logging {
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- def fetchBlocks(
- hostName: String,
+ override def fetchBlocks(
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit
/**
@@ -72,6 +73,7 @@ abstract class BlockTransferService extends Closeable with Logging {
def uploadBlock(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
@@ -81,43 +83,23 @@ abstract class BlockTransferService extends Closeable with Logging {
*
* It is also only available after [[init]] is invoked.
*/
- def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
+ def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
- val lock = new Object
- @volatile var result: Either[ManagedBuffer, Throwable] = null
- fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
- override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
- lock.synchronized {
- result = Right(exception)
- lock.notify()
+ val result = Promise[ManagedBuffer]()
+ fetchBlocks(host, port, execId, Array(blockId),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ result.failure(exception)
}
- }
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- lock.synchronized {
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
val ret = ByteBuffer.allocate(data.size.toInt)
ret.put(data.nioByteBuffer())
ret.flip()
- result = Left(new NioManagedBuffer(ret))
- lock.notify()
+ result.success(new NioManagedBuffer(ret))
}
- }
- })
+ })
- // Sleep until result is no longer null
- lock.synchronized {
- while (result == null) {
- try {
- lock.wait()
- } catch {
- case e: InterruptedException =>
- }
- }
- }
-
- result match {
- case Left(data) => data
- case Right(e) => throw e
- }
+ Await.result(result.future, Duration.Inf)
}
/**
@@ -129,9 +111,10 @@ abstract class BlockTransferService extends Closeable with Logging {
def uploadBlockSync(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
- Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+ Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
}
}
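
The Promise-based rewrite of fetchBlockSync above is a general callback-to-blocking bridge. A minimal, self-contained version of the same idea, where the async source is a stand-in for the Netty listener:

```scala
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global

object PromiseDemo extends App {
  // Stand-in for an async fetch that reports through callbacks.
  def fetchAsync(onSuccess: String => Unit, onFailure: Throwable => Unit): Unit = {
    Future { onSuccess("block-bytes") }
  }

  val result = Promise[String]()
  fetchAsync(s => result.success(s), e => result.failure(e))
  println(Await.result(result.future, Duration.Inf)) // blocks until the callback fires
}
```
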
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
deleted file mode 100644
index 8c5ffd8da6bbb..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.nio.ByteBuffer
-import java.util
-
-import org.apache.spark.{SparkConf, Logging}
-import org.apache.spark.network.BlockFetchingListener
-import org.apache.spark.network.netty.NettyMessages._
-import org.apache.spark.serializer.{JavaSerializer, Serializer}
-import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, TransportClient}
-import org.apache.spark.storage.BlockId
-import org.apache.spark.util.Utils
-
-/**
- * Responsible for holding the state for a request for a single set of blocks. This assumes that
- * the chunks will be returned in the same order as requested, and that there will be exactly
- * one chunk per block.
- *
- * Upon receipt of any block, the listener will be called back. Upon failure part way through,
- * the listener will receive a failure callback for each outstanding block.
- */
-class NettyBlockFetcher(
- serializer: Serializer,
- client: TransportClient,
- blockIds: Seq[String],
- listener: BlockFetchingListener)
- extends Logging {
-
- require(blockIds.nonEmpty)
-
- private val ser = serializer.newInstance()
-
- private var streamHandle: ShuffleStreamHandle = _
-
- private val chunkCallback = new ChunkReceivedCallback {
- // On receipt of a chunk, pass it upwards as a block.
- def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions {
- listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer)
- }
-
- // On receipt of a failure, fail every block from chunkIndex onwards.
- def onFailure(chunkIndex: Int, e: Throwable): Unit = {
- blockIds.drop(chunkIndex).foreach { blockId =>
- listener.onBlockFetchFailure(blockId, e);
- }
- }
- }
-
- /** Begins the fetching process, calling the listener with every block fetched. */
- def start(): Unit = {
- // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle.
- client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(),
- new RpcResponseCallback {
- override def onSuccess(response: Array[Byte]): Unit = {
- try {
- streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response))
- logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.")
-
- // Immediately request all chunks -- we expect that the total size of the request is
- // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
- for (i <- 0 until streamHandle.numChunks) {
- client.fetchChunk(streamHandle.streamId, i, chunkCallback)
- }
- } catch {
- case e: Exception =>
- logError("Failed while starting block fetches", e)
- blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
- }
- }
-
- override def onFailure(e: Throwable): Unit = {
- logError("Failed while starting block fetches", e)
- blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
- }
- })
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 02c657e1d61b5..b089da8596e2b 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -19,58 +19,55 @@ package org.apache.spark.network.netty
import java.nio.ByteBuffer
+import scala.collection.JavaConversions._
+
import org.apache.spark.Logging
import org.apache.spark.network.BlockDataManager
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
+import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.client.{TransportClient, RpcResponseCallback}
-import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler}
-import org.apache.spark.storage.{StorageLevel, BlockId}
-
-import scala.collection.JavaConversions._
-
-object NettyMessages {
-
- /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
- case class OpenBlocks(blockIds: Seq[BlockId])
-
- /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
- case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
-
- /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
- case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
-}
+import org.apache.spark.storage.{BlockId, StorageLevel}
/**
* Serves requests to open blocks by simply registering one chunk per block requested.
+ * Handles opening and uploading arbitrary BlockManager blocks.
+ *
+ * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
+ * is equivalent to one Spark-level shuffle block.
*/
class NettyBlockRpcServer(
serializer: Serializer,
- streamManager: DefaultStreamManager,
blockManager: BlockDataManager)
extends RpcHandler with Logging {
- import NettyMessages._
+ private val streamManager = new OneForOneStreamManager()
override def receive(
client: TransportClient,
messageBytes: Array[Byte],
responseContext: RpcResponseCallback): Unit = {
- val ser = serializer.newInstance()
- val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+ val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
logTrace(s"Received request: $message")
message match {
- case OpenBlocks(blockIds) =>
- val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+ case openBlocks: OpenBlocks =>
+ val blocks: Seq[ManagedBuffer] =
+ openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
val streamId = streamManager.registerStream(blocks.iterator)
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
- responseContext.onSuccess(
- ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+ responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
- case UploadBlock(blockId, blockData, level) =>
- blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+ case uploadBlock: UploadBlock =>
+ // StorageLevel is serialized as bytes using our JavaSerializer.
+ val level: StorageLevel =
+ serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+ val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
+ blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
responseContext.onSuccess(new Array[Byte](0))
}
}
+
+ override def getStreamManager(): StreamManager = streamManager
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 38a3e945155e8..f8a7f640689a2 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,15 +17,17 @@
package org.apache.spark.network.netty
-import scala.concurrent.{Promise, Future}
+import scala.collection.JavaConversions._
+import scala.concurrent.{Future, Promise}
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory}
-import org.apache.spark.network.netty.NettyMessages.UploadBlock
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.server._
-import org.apache.spark.network.util.{ConfigProvider, TransportConf}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.protocol.UploadBlock
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -33,34 +35,59 @@ import org.apache.spark.util.Utils
/**
 * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
- // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
- val serializer = new JavaSerializer(conf)
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+ extends BlockTransferService {
- // Create a TransportConfig using SparkConf.
- private[this] val transportConf = new TransportConf(
- new ConfigProvider { override def get(name: String) = conf.get(name) })
+ // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
+ private val serializer = new JavaSerializer(conf)
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
private[this] var clientFactory: TransportClientFactory = _
+ private[this] var appId: String = _
override def init(blockDataManager: BlockDataManager): Unit = {
- val streamManager = new DefaultStreamManager
- val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager)
- transportContext = new TransportContext(transportConf, streamManager, rpcHandler)
- clientFactory = transportContext.createClientFactory()
+ val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+ val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ if (!authEnabled) {
+ (nettyRpcHandler, None)
+ } else {
+ (new SaslRpcHandler(nettyRpcHandler, securityManager),
+ Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+ }
+ }
+ transportContext = new TransportContext(transportConf, rpcHandler)
+ clientFactory = transportContext.createClientFactory(bootstrap.toList)
server = transportContext.createServer()
+ appId = conf.getAppId
+    logInfo("Server created on port " + server.getPort)
}
override def fetchBlocks(
- hostname: String,
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
+ logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
- val client = clientFactory.createClient(hostname, port)
- new NettyBlockFetcher(serializer, client, blockIds, listener).start()
+ val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+ override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+ }
+ }
+
+ val maxRetries = transportConf.maxIORetries()
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener)
+ }
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
@@ -75,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
override def uploadBlock(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit] = {
val result = Promise[Unit]()
val client = clientFactory.createClient(hostname, port)
+ // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
+ // using our binary protocol.
+ val levelBytes = serializer.newInstance().serialize(level).array()
+
// Convert or copy nio buffer into array in order to serialize it.
val nioBuffer = blockData.nioByteBuffer()
val array = if (nioBuffer.hasArray) {
@@ -91,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
data
}
- val ser = serializer.newInstance()
- client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+ client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
new RpcResponseCallback {
override def onSuccess(response: Array[Byte]): Unit = {
logTrace(s"Successfully uploaded block $blockId")
@@ -107,5 +138,8 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
result.future
}
- override def close(): Unit = server.close()
+ override def close(): Unit = {
+ server.close()
+ clientFactory.close()
+ }
}
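
For context, the retry branch above is driven by the transport configuration. The key name below is an assumption about what maxIORetries() reads, not something confirmed by this diff:

```scala
import org.apache.spark.SparkConf

// Assumed key name for TransportConf.maxIORetries(); treat as illustrative.
val conf = new SparkConf().set("spark.shuffle.io.maxRetries", "3")
// maxRetries > 0 routes fetches through RetryingBlockFetcher;
// 0 falls back to a single createAndStart() attempt.
```
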
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
new file mode 100644
index 0000000000000..9fa4fa77b8817
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.{TransportConf, ConfigProvider}
+
+/**
+ * Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ */
+object SparkTransportConf {
+ def fromSparkConf(conf: SparkConf): TransportConf = {
+ new TransportConf(new ConfigProvider {
+ override def get(name: String): String = conf.get(name)
+ })
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 4f6f5e235811d..c2d9578be7ebb 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -23,12 +23,13 @@ import java.nio.channels._
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.LinkedList
-import org.apache.spark._
-
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
+
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 8408b75bb4d65..f198aa8564a54 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -34,6 +34,7 @@ import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
import org.apache.spark.util.Utils
import scala.util.Try
@@ -600,7 +601,7 @@ private[nio] class ConnectionManager(
} else {
var replyToken : Array[Byte] = null
try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
@@ -634,7 +635,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -778,7 +779,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index 11793ea92adb1..b2aec160635c7 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
@@ -79,13 +80,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
override def fetchBlocks(
- hostName: String,
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
checkInit()
- val cmId = new ConnectionManagerId(hostName, port)
+ val cmId = new ConnectionManagerId(host, port)
val blockMessageArray = new BlockMessageArray(blockIds.map { blockId =>
BlockMessage.fromGetBlock(GetBlock(BlockId(blockId)))
})
@@ -135,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
new file mode 100644
index 0000000000000..6e66ddbdef788
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce._
+
+import org.apache.spark.{Partition, SparkContext}
+import org.apache.spark.input.StreamFileInputFormat
+
+private[spark] class BinaryFileRDD[T](
+ sc: SparkContext,
+ inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
+ keyClass: Class[String],
+ valueClass: Class[T],
+ @transient conf: Configuration,
+ minPartitions: Int)
+ extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val jobContext = newJobContext(conf, jobId)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+}
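
A worked check of the split-size cap that setMinPartitions() applies before getSplits() runs, under assumed sizes:

```scala
// Assumed: 1 GB of input across the listed files, minPartitions = 10.
val totalLen = 1000000000L
val minPartitions = 10
val maxSplitSize = math.ceil(totalLen * 1.0 / minPartitions).toLong
assert(maxSplitSize == 100000000L) // each split capped at ~100 MB
```
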
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 2673ec22509e9..fffa1911f5bc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -84,5 +84,9 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds
"Attempted to use %s after its blocks have been removed!".format(toString))
}
}
+
+ protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = {
+ locations_
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 946fb5616d3ec..a157e36e2286e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -211,20 +211,11 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
val jobConf = getJobConf()
- val inputFormat = getInputFormat(jobConf)
- HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener{ context => closeIfNeeded() }
- val key: K = reader.createKey()
- val value: V = reader.createValue()
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- // Find a function that will return the FileSystem bytes read by this thread.
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf)
@@ -234,6 +225,18 @@ class HadoopRDD[K, V](
if (bytesReadCallback.isDefined) {
context.taskMetrics.inputMetrics = Some(inputMetrics)
}
+
+ var reader: RecordReader[K, V] = null
+ val inputFormat = getInputFormat(jobConf)
+ HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
+ context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
var recordsSinceMetricsUpdate = 0
override def getNext() = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 324563248793c..e55d03d391e03 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -35,6 +35,7 @@ import org.apache.spark.Partition
import org.apache.spark.SerializableWritable
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
import org.apache.spark.deploy.SparkHadoopUtil
@@ -107,20 +108,10 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- format match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- val reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- // Find a function that will return the FileSystem bytes read by this thread.
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf)
@@ -131,6 +122,18 @@ class NewHadoopRDD[K, V](
context.taskMetrics.inputMetrics = Some(inputMetrics)
}
+ val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ format match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
var havePair = false
@@ -263,7 +266,7 @@ private[spark] class WholeTextFileRDD(
case _ =>
}
val jobContext = newJobContext(conf, jobId)
- inputFormat.setMaxSplitSize(jobContext, minPartitions)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
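
The identical reordering in HadoopRDD and NewHadoopRDD matters because the callback obtained from SparkHadoopUtil snapshots a baseline when it is created and reports the delta afterwards; creating it after the reader would silently exclude constructor-time reads. A plain-Scala model of that contract (the mutable counter is a stand-in for Hadoop's per-thread FileSystem statistics):

object BaselineDeltaSketch {
  private var threadBytesRead = 0L // stand-in for Hadoop's per-thread FileSystem statistics

  private def makeBytesReadCallback(): () => Long = {
    val baseline = threadBytesRead   // snapshot taken when the callback is created
    () => threadBytesRead - baseline // reports the delta since creation
  }

  def main(args: Array[String]): Unit = {
    val callback = makeBytesReadCallback() // created BEFORE the reader, as in the diff
    threadBytesRead += 4096                // RecordReader constructor reads 4 KB
    threadBytesRead += 65536               // records read during iteration
    println(callback())                    // 69632: constructor-time bytes are included
  }
}
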
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index da89f634abaea..8c2c959e73bb6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -28,18 +28,20 @@ import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
-RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
+RecordWriter => NewRecordWriter}
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.{DataWriteMethod, OutputMetrics}
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
@@ -961,30 +963,40 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
+ val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
- val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
- case c: Configurable => c.setConf(wrappedConf.value)
+ case c: Configurable => c.setConf(config)
case _ => ()
}
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
+
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
try {
+ var recordsWritten = 0L
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close(hadoopContext)
}
committer.commitTask(hadoopContext)
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
1
} : Int
@@ -1005,6 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def saveAsHadoopDataset(conf: JobConf) {
// Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
val hadoopConf = conf
+ val wrappedConf = new SerializableWritable(hadoopConf)
val outputFormatInstance = hadoopConf.getOutputFormat
val keyClass = hadoopConf.getOutputKeyClass
val valueClass = hadoopConf.getOutputValueClass
@@ -1032,27 +1045,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.preSetup()
val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
+ val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+
writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
try {
+ var recordsWritten = 0L
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close()
}
writer.commit()
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
}
self.context.runJob(self, writeToFile)
writer.commitJob()
}
+ private def initHadoopOutputMetrics(context: TaskContext, config: Configuration)
+ : (OutputMetrics, Option[() => Long]) = {
+ val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir"))
+ .map(new Path(_))
+ .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config))
+ val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
+ if (bytesWrittenCallback.isDefined) {
+ context.taskMetrics.outputMetrics = Some(outputMetrics)
+ }
+ (outputMetrics, bytesWrittenCallback)
+ }
+
+ private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long],
+ outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
+ if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0
+ && bytesWrittenCallback.isDefined) {
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ }
+ }
+
/**
* Return an RDD with the keys of each tuple.
*/
@@ -1069,3 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
}
+
+private[spark] object PairRDDFunctions {
+ val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
+}
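
The two private helpers implement a simple throttle: the bytes-written gauge is refreshed only every RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES records, plus once after the loop. A self-contained sketch of the same pattern, with a constant callback standing in for the real filesystem counter:

object ThrottledUpdateSketch {
  val RecordsBetweenUpdates = 256

  def main(args: Array[String]): Unit = {
    val bytesWrittenCallback: Option[() => Long] = Some(() => 123456L) // stand-in gauge
    var bytesWritten = 0L
    var recordsWritten = 0L
    for (_ <- 1 to 1000) {
      // writer.write(record) would go here
      if (recordsWritten % RecordsBetweenUpdates == 0) {
        bytesWrittenCallback.foreach { fn => bytesWritten = fn() }
      }
      recordsWritten += 1
    }
    bytesWrittenCallback.foreach { fn => bytesWritten = fn() } // final flush
    println(s"recordsWritten=$recordsWritten bytesWritten=$bytesWritten")
  }
}
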
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b7f125d01dfaf..716f2dd17733b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -43,7 +43,8 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
+ SamplingUtils}
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -375,7 +376,8 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
+ new PartitionwiseSampledRDD[T, T](
+ this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
}.toArray
}
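
The switch to BernoulliCellSampler pairs with the cumulative-weight computation already present: each sampler receives one [lb, ub) cell of [0, 1]. A runnable illustration of how the cells fall out:

object WeightCellsSketch {
  def main(args: Array[String]): Unit = {
    val weights = Array(1.0, 2.0, 7.0)
    val sum = weights.sum
    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
    normalizedCumWeights.sliding(2).foreach { case Array(lb, ub) =>
      println(f"[$lb%.2f, $ub%.2f)") // three adjacent cells covering [0, 1]
    }
  }
}
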
@@ -1094,7 +1096,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Returns the top K (largest) elements from this RDD as defined by the specified
+ * Returns the top k (largest) elements from this RDD as defined by the specified
* implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example:
* {{{
* sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1)
@@ -1104,14 +1106,14 @@ abstract class RDD[T: ClassTag](
* // returns Array(6, 5)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse)
/**
- * Returns the first K (smallest) elements from this RDD as defined by the specified
+ * Returns the first k (smallest) elements from this RDD as defined by the specified
* implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]].
* For example:
* {{{
@@ -1122,7 +1124,7 @@ abstract class RDD[T: ClassTag](
* // returns Array(2, 3)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f81fa6d8089fc..22449517d100f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -124,6 +124,9 @@ class DAGScheduler(
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
+ /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
+ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
+
private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
@@ -1050,7 +1053,7 @@ class DAGScheduler(
logInfo("Resubmitted " + task + ", so marking it as still running")
stage.pendingTasks += task
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
@@ -1060,11 +1063,13 @@ class DAGScheduler(
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some("Fetch failure"))
+ markStageAsFinished(failedStage, Some(failureMessage))
runningStages -= failedStage
}
- if (failedStages.isEmpty && eventProcessActor != null) {
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty && eventProcessActor != null) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
@@ -1086,10 +1091,10 @@ class DAGScheduler(
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, Some(task.epoch))
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
- case ExceptionFailure(className, description, stackTrace, metrics) =>
+ case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
@@ -1106,25 +1111,35 @@ class DAGScheduler(
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
+ * We will also assume that we've lost all shuffle blocks associated with the executor if the
+ * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed
+ * occurred, in which case we presume all shuffle data related to this executor to be lost.
+ *
* Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ private[scheduler] def handleExecutorLost(
+ execId: String,
+ fetchFailed: Boolean,
+ maybeEpoch: Option[Long] = None) {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)
- // TODO: This will be really slow if we keep accumulating shuffle map stages
- for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnExecutor(execId)
- val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
- mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
- }
- if (shuffleToMapStage.isEmpty) {
- mapOutputTracker.incrementEpoch()
+
+ if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnExecutor(execId)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
+ }
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementEpoch()
+ }
+ clearCacheLocs()
}
- clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
"(epoch " + currentEpoch + ")")
@@ -1382,7 +1397,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.handleExecutorAdded(execId, host)
case ExecutorLost(execId) =>
- dagScheduler.handleExecutorLost(execId)
+ dagScheduler.handleExecutorLost(execId, fetchFailed = false)
case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
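
A sketch of how a test might opt into the new fail-fast behavior through the flag added above:

import org.apache.spark.{SparkConf, SparkContext}

object NoStageRetrySketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("no-stage-retry")
      .set("spark.test.noStageRetry", "true") // FetchFailed aborts the stage outright
    val sc = new SparkContext(conf)
    // ... run a shuffle-heavy job; a fetch failure now surfaces instead of being retried ...
    sc.stop()
  }
}
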
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 100c9ba9b7809..597dbc884913c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -142,7 +142,7 @@ private[spark] object EventLoggingListener extends Logging {
val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
- val LOG_FILE_PERMISSIONS = FsPermission.createImmutable(Integer.parseInt("770", 8).toShort)
+ val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
// A cache for compression codecs to avoid creating the same codec many times
private val codecMap = new mutable.HashMap[String, CompressionCodec]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 54904bffdf10b..3bb54855bae44 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" INPUT_BYTES=" + metrics.bytesRead
case None => ""
}
+ val outputMetrics = taskMetrics.outputMetrics match {
+ case Some(metrics) =>
+ " OUTPUT_BYTES=" + metrics.bytesWritten
+ case None => ""
+ }
val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
@@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime
case None => ""
}
- stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics +
+ stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics +
shuffleReadMetrics + writeMetrics)
}
@@ -215,7 +220,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + taskEnd.stageId
stageLogInfo(taskEnd.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 071568cdfb429..cc13f57a49b89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -102,6 +102,11 @@ private[spark] class Stage(
}
}
+ /**
+ * Removes all shuffle outputs associated with this executor. Note that this will also remove
+ * outputs which are served by an external shuffle server (if one exists), as they are still
+ * registered with this execId.
+ */
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
@@ -131,4 +136,9 @@ private[spark] class Stage(
override def toString = "Stage " + id
override def hashCode(): Int = id
+
+ override def equals(other: Any): Boolean = other match {
+ case stage: Stage => stage != null && stage.id == id
+ case _ => false
+ }
}
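
Stage has a large private[spark] constructor, so a toy stand-in illustrates the contract the override restores: equality and hashing are both keyed on id, so duplicate references collapse in hash-based collections:

class IdKeyed(val id: Int) {
  override def hashCode(): Int = id
  override def equals(other: Any): Boolean = other match {
    case that: IdKeyed => that.id == id
    case _ => false
  }
}

object IdKeyedDemo {
  def main(args: Array[String]): Unit = {
    println(new IdKeyed(7) == new IdKeyed(7))         // true: same id implies equal
    println(Set(new IdKeyed(7), new IdKeyed(7)).size) // 1: duplicates collapse
  }
}
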
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 11c19eeb6e42c..1f114a0207f7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.Utils
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark]
-case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+ extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 3f345ceeaaf7a..819b51e12ad8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) =>
+ val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case directResult: DirectTaskResult[_] =>
+ if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+ return
+ }
+ (directResult, serializedData.limit())
+ case IndirectTaskResult(blockId, size) =>
+ if (!taskSetManager.canFetchMoreResults(size)) {
+ // dropped by executor if size is larger than maxResultSize
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ return
+ }
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
@@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get)
sparkEnv.blockManager.master.removeBlock(blockId)
- deserializedResult
+ (deserializedResult, size)
}
- result.metrics.resultSize = serializedData.limit()
+
+ result.metrics.resultSize = size
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
@@ -93,7 +103,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
}
} catch {
case cnd: ClassNotFoundException =>
- // Log an error but keep going here -- the task failed, so not catastropic if we can't
+ // Log an error but keep going here -- the task failed, so not catastrophic if we can't
// deserialize the reason.
val loader = Utils.getContextOrSparkClassLoader
logError(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index a129a434c9a1a..f095915352b17 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockManagerId
/**
* Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
- * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
* for a single SparkContext. These schedulers get sets of tasks submitted to them from the
* DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
* them, retrying if there are failures, and mitigating stragglers. They return events to the
@@ -41,7 +41,7 @@ private[spark] trait TaskScheduler {
// Invoked after system has successfully initialized (typically in spark context).
// Yarn uses this to bootstrap allocation of resources based on preferred locations,
- // wait for slave registerations, etc.
+ // wait for slave registrations, etc.
def postStartHook() { }
// Disconnect from the cluster.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a6c23fc85a1b0..d8fb640350343 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -23,13 +23,12 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
+import scala.math.{min, max}
import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{Clock, SystemClock}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -68,6 +67,9 @@ private[spark] class TaskSetManager(
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
+  // Limit on the total size of task results, in bytes (default 1 GB)
+ val maxResultSize = Utils.getMaxResultSize(conf)
+
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
@@ -89,6 +91,8 @@ private[spark] class TaskSetManager(
var stageId = taskSet.stageId
var name = "TaskSet_" + taskSet.stageId.toString
var parent: Pool = null
+ var totalResultSize = 0L
+ var calculatedTasks = 0
val runningTasksSet = new HashSet[Long]
override def runningTasks = runningTasksSet.size
@@ -515,12 +519,33 @@ private[spark] class TaskSetManager(
index
}
+ /**
+   * Marks the task as getting its result and notifies the DAGScheduler.
+ */
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(info)
}
+ /**
+   * Checks whether there is enough quota to fetch the result with `size` bytes.
+ */
+ def canFetchMoreResults(size: Long): Boolean = synchronized {
+ totalResultSize += size
+ calculatedTasks += 1
+ if (maxResultSize > 0 && totalResultSize > maxResultSize) {
+ val msg = s"Total size of serialized results of ${calculatedTasks} tasks " +
+ s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " +
+ s"(${Utils.bytesToString(maxResultSize)})"
+ logError(msg)
+ abort(msg)
+ false
+ } else {
+ true
+ }
+ }
+
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
@@ -687,10 +712,11 @@ private[spark] class TaskSetManager(
addPendingTask(index, readding=true)
}
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage.
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage,
+ // and we are not using an external shuffle server which could serve the shuffle outputs.
// The reason is the next stage wouldn't be able to fetch the data from this dead executor
// so we would need to rerun these tasks on other executors.
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
@@ -706,7 +732,7 @@ private[spark] class TaskSetManager(
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure)
+ handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId))
}
// recalculate valid locality levels and waits when executor is lost
recomputeLocality()
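
The quota enforced by canFetchMoreResults is driver-configurable. A sketch, on the assumption that Utils.getMaxResultSize reads spark.driver.maxResultSize (with 0 disabling the check, per the guard above):

import org.apache.spark.{SparkConf, SparkContext}

object MaxResultSizeSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("max-result-size")
      .set("spark.driver.maxResultSize", "512m") // cap on summed serialized task results
    val sc = new SparkContext(conf)
    // An oversized collect() now aborts with the "Total size of serialized results ..."
    // error from canFetchMoreResults instead of exhausting driver memory.
    sc.stop()
  }
}
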
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index d8c0e2f66df01..5289661eb896b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try { {
val ret = driver.run()
@@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Build a Mesos resource protobuf object */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 8e2faff90f9b2..c5f3493477bc5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -278,8 +278,7 @@ private[spark] class MesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Turn a Spark TaskDescription into a Mesos task */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 58b78f041cd85..c0264836de738 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import akka.actor.{Actor, ActorRef, Props}
-import org.apache.spark.{Logging, SparkEnv, TaskState}
+import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
@@ -47,7 +47,7 @@ private[spark] class LocalActor(
private var freeCores = totalCores
- private val localExecutorId = "localhost"
+ private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
private val localExecutorHostname = "localhost"
val executor = new Executor(
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 71c08e9d5a8c3..be184464e0ae9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.{FetchFailed, TaskEndReason}
+import org.apache.spark.util.Utils
/**
* Failed to fetch a shuffle block. The executor catches this exception and propagates it
@@ -30,13 +31,22 @@ private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
- reduceId: Int)
- extends Exception {
+ reduceId: Int,
+ message: String,
+ cause: Throwable = null)
+ extends Exception(message, cause) {
- override def getMessage: String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+ def this(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int,
+ cause: Throwable) {
+ this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
+ }
- def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
+ Utils.exceptionString(this))
}
/**
@@ -46,7 +56,4 @@ private[spark] class MetadataFetchFailedException(
shuffleId: Int,
reduceId: Int,
message: String)
- extends FetchFailedException(null, shuffleId, -1, reduceId) {
-
- override def getMessage: String = message
-}
+ extends FetchFailedException(null, shuffleId, -1, reduceId, message)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 1fb5b2c4546bd..7de2f9cbb2866 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -27,6 +27,7 @@ import scala.collection.JavaConversions._
import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
@@ -62,11 +63,14 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
* files within a ShuffleFileGroups associated with the block's reducer.
*/
-
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData().
private[spark]
class FileShuffleBlockManager(conf: SparkConf)
extends ShuffleBlockManager with Logging {
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
private lazy val blockManager = SparkEnv.get.blockManager
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
@@ -181,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf)
val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
if (segmentOpt.isDefined) {
val segment = segmentOpt.get
- return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length)
+ return new FileSegmentManagedBuffer(
+ transportConf, segment.file, segment.offset, segment.length)
}
}
throw new IllegalStateException("Failed to find shuffle block: " + blockId)
} else {
val file = blockManager.diskBlockManager.getFile(blockId)
- new FileSegmentManagedBuffer(file, 0, file.length)
+ new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index e9805c9c134b5..b292587d37028 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -22,8 +22,9 @@ import java.nio.ByteBuffer
import com.google.common.io.ByteStreams
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.storage._
/**
@@ -35,11 +36,15 @@ import org.apache.spark.storage._
* as the filename postfix for data file, and ".index" as the filename postfix for index file.
*
*/
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData().
private[spark]
-class IndexShuffleBlockManager extends ShuffleBlockManager {
+class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager {
private lazy val blockManager = SparkEnv.get.blockManager
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
/**
* Mapping to a single shuffleBlockId with reduce ID 0.
* */
@@ -107,6 +112,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
val offset = in.readLong()
val nextOffset = in.readLong()
new FileSegmentManagedBuffer(
+ transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),
offset,
nextOffset - offset)
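
The read above leans on the ".index" layout: a flat sequence of longs, one offset per partition boundary, so segment i spans [offset_i, offset_i+1). A runnable sketch of that layout using plain java.io, no Spark types:

import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}

object IndexFileLayoutSketch {
  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("shuffle_0_0_0", ".index")
    val out = new DataOutputStream(new FileOutputStream(f))
    Seq(0L, 100L, 250L, 400L).foreach(v => out.writeLong(v)) // boundaries for 3 partitions
    out.close()

    val reduceId = 1
    val in = new DataInputStream(new FileInputStream(f))
    in.skipBytes(reduceId * 8) // each entry is one 8-byte long
    val offset = in.readLong()
    val nextOffset = in.readLong()
    in.close()
    println(s"reduce $reduceId -> bytes [$offset, $nextOffset), length ${nextOffset - offset}")
  }
}
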
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 6cf9305977a3c..e3e7434df45b0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.util.{Failure, Success, Try}
import org.apache.spark._
import org.apache.spark.serializer.Serializer
@@ -52,21 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
- case Some(block) => {
+ case Success(block) => {
block.asInstanceOf[Iterator[T]]
}
- case None => {
+ case Failure(e) => {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId)
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block")
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
}
}
}
@@ -74,7 +75,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
- SparkEnv.get.blockTransferService,
+ SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
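
With the fetcher yielding Try instead of Option, the failure branch keeps the root cause and chains it into the rethrown exception. A self-contained sketch of the pattern:

import scala.util.{Failure, Success, Try}

object TryUnpackSketch {
  def unpack(block: Try[Iterator[Int]]): Iterator[Int] = block match {
    case Success(it) => it
    case Failure(e)  => throw new RuntimeException("fetch failed", e) // cause preserved
  }

  def main(args: Array[String]): Unit = {
    println(unpack(Success(Iterator(1, 2, 3))).sum) // 6
    try unpack(Failure(new java.io.IOException("connection reset")))
    catch { case e: RuntimeException => println(e.getCause) } // java.io.IOException: connection reset
  }
}
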
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 746ed33b54c00..183a30373b28c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
- MapStatus(blockManager.blockManagerId, sizes)
+ MapStatus(blockManager.shuffleServerId, sizes)
}
private def revertWrites(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b727438ae7e47..bda30a56d808e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager {
- private val indexShuffleBlockManager = new IndexShuffleBlockManager()
+ private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf)
private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 927481b72cf4f..d75f9d7311fad 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
- mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 8df5ec6bde184..1f012941c85ab 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def name = "rdd_" + rddId + "_" + splitIndex
}
+// Format of the shuffle block ids (including data and index) should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData().
@DeveloperApi
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 58510d7232436..39434f473a9d8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,9 +21,9 @@ import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream,
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
-import scala.concurrent.{Await, Future}
import scala.util.Random
import akka.actor.{ActorSystem, Props}
@@ -34,8 +34,13 @@ import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
+import org.apache.spark.network.shuffle.ExternalShuffleClient
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._
private[spark] sealed trait BlockValues
@@ -52,6 +57,12 @@ private[spark] class BlockResult(
inputMetrics.bytesRead = bytes
}
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -61,11 +72,10 @@ private[spark] class BlockManager(
val conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService)
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager)
extends BlockDataManager with Logging {
- blockTransferService.init(this)
-
val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -85,8 +95,37 @@ private[spark] class BlockManager(
new TachyonStore(this, tachyonBlockManager)
}
- val blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ private[spark]
+ val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
+
+  // Port used by the external shuffle service. In Yarn mode, this may already be
+ // set through the Hadoop configuration as the server is launched in the Yarn NM.
+ private val externalShuffleServicePort =
+ Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
+
+ // Check that we're not using external shuffle service with consolidated shuffle files.
+ if (externalShuffleServiceEnabled
+ && conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ && shuffleManager.isInstanceOf[HashShuffleManager]) {
+ throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated"
+ + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or "
+ + " switch to sort-based shuffle.")
+ }
+
+ var blockManagerId: BlockManagerId = _
+
+ // Address of the server that serves this executor's shuffle files. This is either an external
+ // service, or just our own Executor's BlockManager.
+ private[spark] var shuffleServerId: BlockManagerId = _
+
+ // Client to read other executors' shuffle files. This is either an external service, or just the
+  // standard BlockTransferService to directly connect to other Executors.
+ private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
+ new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager,
+ securityManager.isAuthenticationEnabled())
+ } else {
+ blockTransferService
+ }
// Whether to compress broadcast variables that are stored
private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
@@ -116,8 +155,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- initialize()
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -136,17 +173,65 @@ private[spark] class BlockManager(
conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService) = {
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, mapOutputTracker, shuffleManager, blockTransferService)
+ conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager)
}
/**
- * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor.
+ * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+ * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+ * where it is only learned after registration with the TaskScheduler).
+ *
+ * This method initializes the BlockTransferService and ShuffleClient, registers with the
+ * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * service if configured.
*/
- private def initialize(): Unit = {
+ def initialize(appId: String): Unit = {
+ blockTransferService.init(this)
+ shuffleClient.init(appId)
+
+ blockManagerId = BlockManagerId(
+ executorId, blockTransferService.hostName, blockTransferService.port)
+
+ shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+
+    // Register Executors' configuration with the local shuffle service, if one exists.
+ if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
+ registerWithExternalShuffleServer()
+ }
+ }
+
+ private def registerWithExternalShuffleServer() {
+ logInfo("Registering executor with local external shuffle service.")
+ val shuffleConfig = new ExecutorShuffleInfo(
+ diskBlockManager.localDirs.map(_.toString),
+ diskBlockManager.subDirsPerLocalDir,
+ shuffleManager.getClass.getName)
+
+ val MAX_ATTEMPTS = 3
+ val SLEEP_TIME_SECS = 5
+
+ for (i <- 1 to MAX_ATTEMPTS) {
+ try {
+ // Synchronous and will throw an exception if we cannot connect.
+ shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer(
+ shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig)
+ return
+ } catch {
+ case e: Exception if i < MAX_ATTEMPTS =>
+ logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}"
+ + s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
+ Thread.sleep(SLEEP_TIME_SECS * 1000)
+ }
+ }
}
/**
@@ -506,7 +591,7 @@ private[spark] class BlockManager(
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
val data = blockTransferService.fetchBlockSync(
- loc.host, loc.port, blockId.toString).nioByteBuffer()
+ loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
if (data != null) {
if (asBlockResult) {
@@ -855,7 +940,7 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+ peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
.format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
@@ -1113,6 +1198,10 @@ private[spark] class BlockManager(
def stop(): Unit = {
blockTransferService.close()
+ if (shuffleClient ne blockTransferService) {
+ // Closing should be idempotent, but maybe not for the NioBlockTransferService.
+ shuffleClient.close()
+ }
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
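
registerWithExternalShuffleServer uses a bounded retry in which only non-final attempts swallow the exception. A plain-Scala sketch of that shape:

object BoundedRetrySketch {
  def retry[T](maxAttempts: Int, sleepMillis: Long)(op: => T): T = {
    var attempt = 1
    while (true) {
      try {
        return op
      } catch {
        case e: Exception if attempt < maxAttempts => // the final attempt propagates
          println(s"attempt $attempt failed: ${e.getMessage}; retrying")
          Thread.sleep(sleepMillis)
          attempt += 1
      }
    }
    throw new IllegalStateException("unreachable")
  }

  def main(args: Array[String]): Unit = {
    var calls = 0
    val result = retry(maxAttempts = 3, sleepMillis = 10) {
      calls += 1
      if (calls < 3) throw new RuntimeException(s"boom $calls") else "ok"
    }
    println(s"$result after $calls calls")
  }
}
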
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 259f423c73e6b..b177a59c721df 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
@@ -59,7 +60,7 @@ class BlockManagerId private (
def port: Int = port_
- def isDriver: Boolean = (executorId == "")
+ def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER }
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeUTF(executorId_)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index d08e1419e3e41..b63c7f191155c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -88,6 +88,10 @@ class BlockManagerMaster(
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ }
+
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 5e375a2553979..685b2e11440fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetPeers(blockManagerId) =>
sender ! getPeers(blockManagerId)
+ case GetActorSystemHostPortForExecutor(executorId) =>
+ sender ! getActorSystemHostPortForExecutor(executorId)
+
case GetMemoryStatus =>
sender ! memoryStatus
@@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
Seq.empty
}
}
+
+ /**
+ * Returns the hostname and port of an executor's actor system, based on the Akka address of its
+ * BlockManagerSlaveActor.
+ */
+ private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ for (
+ blockManagerId <- blockManagerIdByExecutor.get(executorId);
+ info <- blockManagerInfo.get(blockManagerId);
+ host <- info.slaveActor.path.address.host;
+ port <- info.slaveActor.path.address.port
+ ) yield {
+ (host, port)
+ }
+ }
}
@DeveloperApi
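
The new lookup composes four Option-producing steps in a for-comprehension; any None short-circuits the result to None. A minimal runnable version of the idiom:

object OptionForSketch {
  def main(args: Array[String]): Unit = {
    val hostOpt: Option[String] = Some("worker-1")
    val portOpt: Option[Int] = Some(38475)

    println(for (host <- hostOpt; port <- portOpt) yield (host, port)) // Some((worker-1,38475))
    println(for (host <- hostOpt; port <- (None: Option[Int])) yield (host, port)) // None
  }
}
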
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 291ddfcc113ac..3f32099d08cc9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+ case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
case object StopBlockManagerMaster extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 99e925328a4b9..58fba54710510 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
extends Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+ private[spark]
+ val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
* having really large inodes at the top level. */
- val localDirs: Array[File] = createLocalDirs(conf)
+ private[spark] val localDirs: Array[File] = createLocalDirs(conf)
if (localDirs.isEmpty) {
logError("Failed to create any local dir.")
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
@@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
addShutdownHook()
+ /** Looks up a file by hashing it into one of our local subdirectories. */
+ // This method should be kept in sync with
+ // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile().
def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
@@ -159,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
- localDirs.foreach { localDir =>
- if (localDir.isDirectory() && localDir.exists()) {
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case e: Exception =>
- logError(s"Exception while deleting local spark dir: $localDir", e)
+ // Only perform cleanup if an external service is not serving our shuffle files.
+ if (!blockManager.externalShuffleServiceEnabled) {
+ localDirs.foreach { localDir =>
+ if (localDir.isDirectory() && localDir.exists()) {
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while deleting local spark dir: $localDir", e)
+ }
}
}
}
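
Since an external shuffle service may now serve these files, the hash layout used by getFile() has to be reproducible outside the JVM that wrote the blocks. A minimal sketch of the two-level scheme, with hypothetical helper and parameter names (not part of this patch):

{% highlight scala %}
// Maps a block file name to (local dir index, subdirectory index); both sides
// of the shuffle service must compute the same pair for the same file name.
def hashToSubDir(filename: String, numLocalDirs: Int, subDirsPerLocalDir: Int): (Int, Int) = {
  val hash = filename.hashCode & Integer.MAX_VALUE  // non-negative stand-in for Utils.nonNegativeHash
  val dirId = hash % numLocalDirs                   // which spark.local.dir entry
  val subDirId = (hash / numLocalDirs) % subDirsPerLocalDir  // which subdirectory inside it
  (dirId, subDirId)
}
{% endhighlight %}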
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 0d6f3bf003a9d..6b1f57a069431 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -20,9 +20,11 @@ package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.util.{Failure, Success, Try}
import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.{CompletionIterator, Utils}
@@ -38,8 +40,8 @@ import org.apache.spark.util.{CompletionIterator, Utils}
* using too much memory.
*
* @param context [[TaskContext]], used for metrics update
- * @param blockTransferService [[BlockTransferService]] for fetching remote blocks
- * @param blockManager [[BlockManager]] for reading local blocks
+ * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
@@ -49,12 +51,12 @@ import org.apache.spark.util.{CompletionIterator, Utils}
private[spark]
final class ShuffleBlockFetcherIterator(
context: TaskContext,
- blockTransferService: BlockTransferService,
+ shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
+ extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
import ShuffleBlockFetcherIterator._
@@ -90,7 +92,7 @@ final class ShuffleBlockFetcherIterator(
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
* in case of a runtime exception when processing the current buffer.
*/
- private[this] var currentResult: FetchResult = null
+ @volatile private[this] var currentResult: FetchResult = null
/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -117,16 +119,18 @@ final class ShuffleBlockFetcherIterator(
private[this] def cleanup() {
isZombie = true
// Release the current buffer if necessary
- if (currentResult != null && !currentResult.failed) {
- currentResult.buf.release()
+ currentResult match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
}
// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
val result = iter.next()
- if (!result.failed) {
- result.buf.release()
+ result match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
}
}
}
@@ -140,7 +144,8 @@ final class ShuffleBlockFetcherIterator(
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
- blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
+ val address = req.address
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
@@ -149,7 +154,7 @@ final class ShuffleBlockFetcherIterator(
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
- results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
shuffleMetrics.remoteBytesRead += buf.size
shuffleMetrics.remoteBlocksFetched += 1
}
@@ -158,7 +163,7 @@ final class ShuffleBlockFetcherIterator(
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FetchResult(BlockId(blockId), -1, null))
+ results.put(new FailureFetchResult(BlockId(blockId), e))
}
}
)
@@ -179,7 +184,7 @@ final class ShuffleBlockFetcherIterator(
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
- if (address == blockManager.blockManagerId) {
+ if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
@@ -229,12 +234,12 @@ final class ShuffleBlockFetcherIterator(
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.localBlocksFetched += 1
buf.retain()
- results.put(new FetchResult(blockId, 0, buf))
+ results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(blockId, -1, null))
+ results.put(new FailureFetchResult(blockId, e))
return
}
}
@@ -265,15 +270,17 @@ final class ShuffleBlockFetcherIterator(
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
- override def next(): (BlockId, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Try[Iterator[Any]]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
- if (!result.failed) {
- bytesInFlight -= result.size
+
+ result match {
+ case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case _ =>
}
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
@@ -281,20 +288,21 @@ final class ShuffleBlockFetcherIterator(
sendRequest(fetchRequests.dequeue())
}
- val iteratorOpt: Option[Iterator[Any]] = if (result.failed) {
- None
- } else {
- val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream())
- val iter = serializer.newInstance().deserializeStream(is).asIterator
- Some(CompletionIterator[Any, Iterator[Any]](iter, {
- // Once the iterator is exhausted, release the buffer and set currentResult to null
- // so we don't release it again in cleanup.
- currentResult = null
- result.buf.release()
- }))
+ val iteratorTry: Try[Iterator[Any]] = result match {
+ case FailureFetchResult(_, e) => Failure(e)
+ case SuccessFetchResult(blockId, _, buf) => {
+ val is = blockManager.wrapForCompression(blockId, buf.createInputStream())
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ Success(CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ buf.release()
+ }))
+ }
}
- (result.blockId, iteratorOpt)
+ (result.blockId, iteratorTry)
}
}
@@ -313,14 +321,30 @@ object ShuffleBlockFetcherIterator {
}
/**
- * Result of a fetch from a remote block. A failure is represented as size == -1.
+ * Result of a fetch from a remote block.
+ */
+ private[storage] sealed trait FetchResult {
+ val blockId: BlockId
+ }
+
+ /**
+ * Result of a successful fetch from a remote block.
* @param blockId block id
* @param size estimated size of the block, used to calculate bytesInFlight.
- * Note that this is NOT the exact bytes. -1 if failure is present.
- * @param buf [[ManagedBuffer]] for the content. null is error.
+ * Note that this is NOT the exact bytes.
+ * @param buf [[ManagedBuffer]] for the content.
*/
- case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) {
- def failed: Boolean = size == -1
- if (failed) assert(buf == null) else assert(buf != null)
+ private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ extends FetchResult {
+ require(buf != null)
+ require(size >= 0)
}
+
+ /**
+ * Result of a failed fetch from a remote block.
+ * @param blockId block id
+ * @param e the failure exception
+ */
+ private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ extends FetchResult
}
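
Callers of the iterator now see fetch failures as values rather than null buffers and sentinel sizes. A minimal consumption sketch, assuming a `fetchIterator` of this class and placeholder record handling:

{% highlight scala %}
import scala.util.{Failure, Success}

fetchIterator.foreach {
  case (blockId, Success(records)) =>
    records.foreach(record => ())  // consume the deserialized records here
  case (blockId, Failure(e)) =>
    sys.error(s"Failed to fetch block $blockId: ${e.getMessage}")
}
{% endhighlight %}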
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index d9066f766476e..def49e80a3605 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import scala.collection.mutable
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
@@ -59,10 +60,9 @@ class StorageStatusListener extends SparkListener {
val info = taskEnd.taskInfo
val metrics = taskEnd.taskMetrics
if (info != null && metrics != null) {
- val execId = formatExecutorId(info.executorId)
val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
if (updatedBlocks.length > 0) {
- updateStorageStatus(execId, updatedBlocks)
+ updateStorageStatus(info.executorId, updatedBlocks)
}
}
}
@@ -88,13 +88,4 @@ class StorageStatusListener extends SparkListener {
}
}
- /**
- * In the local mode, there is a discrepancy between the executor ID according to the
- * task ("localhost") and that according to SparkEnv (""). In the UI, this
- * results in duplicate rows for the same executor. Thus, in this mode, we aggregate
- * these two rows and use the executor ID of "" to be consistent.
- */
- def formatExecutorId(execId: String): String = {
- if (execId == "localhost") "" else execId
- }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 9ced9b8107ebf..6f446c5a95a0a 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,11 +24,28 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
+ val TASK_DESERIALIZATION_TIME =
+ """Time spent deserializating the task closure on the executor."""
+
val INPUT = "Bytes read from Hadoop or from Spark storage."
+ val OUTPUT = "Bytes written to Hadoop."
+
val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage."
val SHUFFLE_READ =
"""Bytes read from remote executors. Typically less than shuffle write bytes
because this does not include shuffle data read locally."""
+
+ val GETTING_RESULT_TIME =
+ """Time that the driver spends fetching task results from workers. If this is large, consider
+ decreasing the amount of data returned from each task."""
+
+ val RESULT_SERIALIZATION_TIME =
+ """Time spent serializing the task result on the executor before sending it back to the
+ driver."""
+
+ val GC_TIME =
+ """Time that the executor spent paused for Java garbage collection while the task was
+ running."""
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 76714b1e6964f..3312671b6f885 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -20,13 +20,13 @@ package org.apache.spark.ui
import java.text.SimpleDateFormat
import java.util.{Locale, Date}
-import scala.xml.{Text, Node}
+import scala.xml.{Node, Text}
import org.apache.spark.Logging
/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
- val TABLE_CLASS = "table table-bordered table-striped table-condensed sortable"
+ val TABLE_CLASS = "table table-bordered table-striped-custom table-condensed sortable"
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
@@ -160,6 +160,8 @@ private[spark] object UIUtils extends Logging {
+
+
}
/** Returns a spark page with correctly formatted headers */
@@ -240,7 +242,8 @@ private[spark] object UIUtils extends Logging {
generateDataRow: T => Seq[Node],
data: Iterable[T],
fixedWidth: Boolean = false,
- id: Option[String] = None): Seq[Node] = {
+ id: Option[String] = None,
+ headerClasses: Seq[String] = Seq.empty): Seq[Node] = {
var listingTableClass = TABLE_CLASS
if (fixedWidth) {
@@ -248,20 +251,29 @@ private[spark] object UIUtils extends Logging {
}
val colWidth = 100.toDouble / headers.size
val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
- val headerRow: Seq[Node] = {
- // if none of the headers have "\n" in them
- if (headers.forall(!_.contains("\n"))) {
- // represent header as simple text
-        headers.map(h => <th width={colWidthAttr}>{h}</th>)
+
+ def getClass(index: Int): String = {
+ if (index < headerClasses.size) {
+ headerClasses(index)
} else {
- // represent header text as list while respecting "\n"
- headers.map { case h =>
-
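
The new `headerClasses` parameter is matched to `headers` by position; indices past the end of `headerClasses` fall back to an empty class. A call-site sketch with illustrative table contents and class names (not from this patch):

{% highlight scala %}
import scala.xml.Node

val headers = Seq("Executor ID", "Address")
def makeRow(e: (String, String)): Seq[Node] = <tr><td>{e._1}</td><td>{e._2}</td></tr>

val table: Seq[Node] = UIUtils.listingTable(
  headers,
  makeRow,
  data = Seq(("0", "host-a:7337"), ("1", "host-b:7337")),
  headerClasses = Seq("sorttable_alpha"))  // second column keeps the default class
{% endhighlight %}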
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
new file mode 100644
index 0000000000000..e9c755e36f716
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.util.Try
+import scala.xml.{Text, Node}
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") {
+
+ private val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val executorId = Option(request.getParameter("executorId")).getOrElse {
+ return Text(s"Missing executorId parameter")
+ }
+ val time = System.currentTimeMillis()
+ val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
+
+ val content = maybeThreadDump.map { threadDump =>
+ val dumpRows = threadDump.map { thread =>
+
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index b5207360510dd..8bbde51e1801c 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val failedStages = ListBuffer[StageInfo]()
val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData]
val stageIdToInfo = new HashMap[StageId, StageInfo]
+
+  // Number of completed and failed stages. These may not equal completedStages.size and
+  // failedStages.size respectively, because completedStages and failedStages only retain
+  // the most recent stages; earlier ones are removed when there are too many stages, to
+  // bound memory usage.
+ var numCompletedStages = 0
+ var numFailedStages = 0
// Map from pool name to a hash map (map from stage id to StageInfo).
val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
@@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
completedStages += stage
+ numCompletedStages += 1
trimIfNecessary(completedStages)
} else {
failedStages += stage
+ numFailedStages += 1
trimIfNecessary(failedStages)
}
}
@@ -250,6 +259,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.inputBytes += inputBytesDelta
execSummary.inputBytes += inputBytesDelta
+ val outputBytesDelta =
+ (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L)
+ - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L))
+ stageData.outputBytes += outputBytesDelta
+ execSummary.outputBytes += outputBytesDelta
+
val diskSpillDelta =
taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L)
stageData.diskBytesSpilled += diskSpillDelta
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
index 6e718eecdd52a..83a7898071c9b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
@@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
listener.synchronized {
val activeStages = listener.activeStages.values.toSeq
val completedStages = listener.completedStages.reverse.toSeq
+ val numCompletedStages = listener.numCompletedStages
val failedStages = listener.failedStages.reverse.toSeq
+ val numFailedStages = listener.numFailedStages
val now = System.currentTimeMillis
val activeStagesTable =
@@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
++
failedStagesTable.toNodeSeq
UIUtils.headerSparkPage("Spark Stages", content, parent)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 2414e4c65237e..16bc3f6c18d09 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -22,10 +22,13 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, Unparsed}
+import org.apache.commons.lang3.StringEscapeUtils
+
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
import org.apache.spark.ui.jobs.UIData._
import org.apache.spark.util.{Utils, Distribution}
-import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
@@ -52,12 +55,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val numCompleted = tasks.count(_.taskInfo.finished)
val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables
+ val hasAccumulators = accumulables.size > 0
val hasInput = stageData.inputBytes > 0
+ val hasOutput = stageData.outputBytes > 0
val hasShuffleRead = stageData.shuffleReadBytes > 0
val hasShuffleWrite = stageData.shuffleWriteBytes > 0
val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0
- // scalastyle:off
val summary =
@@ -65,55 +69,121 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
Total task time across all tasks:
{UIUtils.formatDuration(stageData.executorRunTime)}
- {if (hasInput)
+ {if (hasInput) {
+:
+ getFormattedTimeQuantiles(gettingResultTimes)
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- val totalExecutionTime = {
- if (info.gettingResultTime > 0) {
- (info.gettingResultTime - info.launchTime).toDouble
- } else {
- (info.finishTime - info.launchTime).toDouble
- }
- }
- totalExecutionTime - metrics.get.executorRunTime
+ getSchedulerDelay(info, metrics.get).toDouble
}
val schedulerDelayTitle =
-
# Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
add the `-Phive` profile to your existing build options. By default Spark
will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using
-the `-Phive-0.12.0` profile. NOTE: currently the JDBC server is only
-supported for Hive 0.12.0.
+the `-Phive-0.12.0` profile.
{% highlight bash %}
# Apache Hadoop 2.4.X with Hive 13 support
mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
@@ -121,8 +118,8 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o
Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence:
- mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-0.12.0 clean package
- mvn -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test
+ mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package
+ mvn -Pyarn -Phadoop-2.3 -Phive test
The ScalaTest plugin also supports running only a specific test suite as follows:
@@ -185,16 +182,16 @@ can be set to control the SBT build. For example:
Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence:
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 assembly
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive test
To run only a specific test suite:
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 "test-only org.apache.spark.repl.ReplSuite"
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite"
To run the test suites of a specific sub-project:
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 core/test
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test
# Speeding up Compilation with Zinc
diff --git a/docs/configuration.md b/docs/configuration.md
index 3007706a2586e..f0b396e21f198 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -21,16 +21,22 @@ application. These properties can be set directly on a
[SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your
`SparkContext`. `SparkConf` allows you to configure some of the common properties
(e.g. master URL and application name), as well as arbitrary key-value pairs through the
-`set()` method. For example, we could initialize an application as follows:
+`set()` method. For example, we could initialize an application with two threads as follows:
+
+Note that we run with local[2], meaning two threads, which represents "minimal" parallelism
+and can help detect bugs that only exist when we run in a distributed context.
{% highlight scala %}
val conf = new SparkConf()
- .setMaster("local")
+ .setMaster("local[2]")
.setAppName("CountingSheep")
.set("spark.executor.memory", "1g")
val sc = new SparkContext(conf)
{% endhighlight %}
+Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may
+actually require more than one thread to prevent starvation issues.
+
## Dynamically Loading Spark Properties
In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For
instance, if you'd like to run the same application with different masters or different
@@ -111,6 +117,18 @@ of the most common options to set are:
    (e.g. 512m, 2g).
  </td>
</tr>
+<tr>
+  <td><code>spark.driver.maxResultSize</code></td>
+  <td>1g</td>
+  <td>
+    Limit of total size of serialized results of all partitions for each Spark action (e.g. collect).
+    Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size
+    is above this limit.
+    Having a high limit may cause out-of-memory errors in the driver (depends on spark.driver.memory
+    and the memory overhead of objects in the JVM). Setting a proper limit can protect the driver
+    from out-of-memory errors.
+  </td>
+</tr>
<tr>
  <td><code>spark.serializer</code></td>
  <td>org.apache.spark.serializer.<br />JavaSerializer</td>
@@ -359,6 +377,16 @@ Apart from these, the following properties are also available, and may be useful
    map-side aggregation and there are at most this many reduce partitions.
  </td>
</tr>
+<tr>
+  <td><code>spark.shuffle.blockTransferService</code></td>
+  <td>netty</td>
+  <td>
+    Implementation to use for transferring shuffle and cached blocks between executors. There
+    are two implementations available: <code>netty</code> and <code>nio</code>. Netty-based
+    block transfer is intended to be simpler but equally efficient and is the default option
+    starting in 1.2.
+  </td>
+</tr>
#### Spark UI
@@ -534,6 +562,9 @@ Apart from these, the following properties are also available, and may be useful
  <td><code>spark.default.parallelism</code></td>
  <td>
+    For distributed shuffle operations like <code>reduceByKey</code> and <code>join</code>, the
+    largest number of partitions in a parent RDD. For operations like <code>parallelize</code>
+    with no parent RDDs, it depends on the cluster manager:
    <ul>
      <li>Local mode: number of cores on the local machine</li>
      <li>Mesos fine grained mode: 8</li>
@@ -541,8 +572,8 @@ Apart from these, the following properties are also available, and may be useful
  </td>
  <td>
-    Default number of tasks to use across the cluster for distributed shuffle operations
-    (<code>groupByKey</code>, <code>reduceByKey</code>, etc) when not set by user.
+    Default number of partitions in RDDs returned by transformations like <code>join</code>,
+    <code>reduceByKey</code>, and <code>parallelize</code> when not set by user.
  </td>
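
A hypothetical `SparkConf` exercising the properties documented above (the values are placeholders, not recommendations):

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.driver.maxResultSize", "2g")           // abort actions whose collected results exceed 2g
  .set("spark.shuffle.blockTransferService", "nio")  // opt back into the pre-1.2 transfer service
  .set("spark.default.parallelism", "200")           // default partition count for shuffles
{% endhighlight %}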
diff --git a/docs/index.md b/docs/index.md
index edd622ec90f64..171d6ddad62f3 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -112,6 +112,7 @@ options for deployment:
**External Resources:**
* [Spark Homepage](http://spark.apache.org)
+* [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK)
* [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here
* [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and
exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/),
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 7978e934fb36b..c696ae9c8e8c8 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
-## Examples
+### Examples
@@ -153,3 +153,97 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-ap
section of the Spark
Quick Start guide. Be sure to also include *spark-mllib* to your build file as
a dependency.
+
+## Streaming clustering
+
+When data arrive in a stream, we may want to estimate clusters dynamically,
+updating them as new data arrive. MLlib provides support for streaming k-means clustering,
+with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm
+uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign
+all points to their nearest cluster, compute new cluster centers, then update each cluster using:
+
+`\begin{equation}
+ c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
+\end{equation}`
+`\begin{equation}
+ n_{t+1} = n_t + m_t
+\end{equation}`
+
+Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned
+to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$`
+is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
+can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
+with `$\alpha$=0` only the most recent data will be used. This is analogous to an
+exponentially-weighted moving average.
+
+The decay can be specified using a `halfLife` parameter, which determines the
+correct decay factor `a` such that, for data acquired
+at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
+The unit of time can be specified either as `batches` or `points` and the update rule
+will be adjusted accordingly.
+
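+For example, with time measured in batches, a half life of `$t_{1/2}$` batches
+corresponds to the decay factor
+
+`\begin{equation}
+  \alpha = e^{\ln(0.5) / t_{1/2}} = 0.5^{1 / t_{1/2}}
+\end{equation}`
+
+so that a point's weight is halved after `$t_{1/2}$` batches.
+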
+### Examples
+
+This example shows how to estimate clusters on streaming data.
+
+
+
+
+
+First we import the necessary classes.
+
+{% highlight scala %}
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+
+{% endhighlight %}
+
+Then we make an input stream of vectors for training, as well as a stream of labeled data
+points for testing. We assume a StreamingContext `ssc` has been created, see
+[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
+
+{% highlight scala %}
+
+val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
+val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)
+
+{% endhighlight %}
+
+We create a model with random clusters and specify the number of clusters to find.
+
+{% highlight scala %}
+
+val numDimensions = 3
+val numClusters = 2
+val model = new StreamingKMeans()
+ .setK(numClusters)
+ .setDecayFactor(1.0)
+ .setRandomCenters(numDimensions, 0.0)
+
+{% endhighlight %}
+
+Now register the streams for training and testing and start the job, printing
+the predicted cluster assignments on new data points as they arrive.
+
+{% highlight scala %}
+
+model.trainOn(trainingData)
+model.predictOnValues(testData).print()
+
+ssc.start()
+ssc.awaitTermination()
+
+{% endhighlight %}
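+
+The decay can equivalently be configured by half life; a minimal sketch (the
+half life value here is illustrative):
+
+{% highlight scala %}
+
+val modelByHalfLife = new StreamingKMeans()
+  .setK(numClusters)
+  .setHalfLife(5, "batches")
+  .setRandomCenters(numDimensions, 0.0)
+
+{% endhighlight %}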
+
+As you add new text files with data, the cluster centers will update. Each training
+point should be formatted as `[x1, x2, x3]`, and each test data point
+should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier
+(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir`,
+the model will update. Anytime a text file is placed in `/testing/data/dir`,
+you will see predictions. With new data, the cluster centers will change!
+
+
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to
+run Pearson's chi-squared tests. The following example demonstrates how to run and interpret
+hypothesis tests.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.linalg import Vectors, Matrices
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.stat import Statistics
+
+sc = SparkContext()
+
+vec = Vectors.dense(...) # a vector composed of the frequencies of events
+
+# compute the goodness of fit. If a second vector to test against is not supplied as a parameter,
+# the test runs against a uniform distribution.
+goodnessOfFitTestResult = Statistics.chiSqTest(vec)
+print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom,
+ # test statistic, the method used, and the null hypothesis.
+
+mat = Matrices.dense(...) # a contingency matrix
+
+# conduct Pearson's independence test on the input contingency matrix
+independenceTestResult = Statistics.chiSqTest(mat)
+print independenceTestResult # summary of the test including the p-value, degrees of freedom...
+
+obs = sc.parallelize(...)  # an RDD of LabeledPoint(label, features)
+
+# The contingency table is constructed from an RDD of LabeledPoint and used to conduct
+# the independence test. Returns an array containing the ChiSquaredTestResult for every feature
+# against the label.
+featureTestResults = Statistics.chiSqTest(obs)
+
+for i, result in enumerate(featureTestResults):
+    print "Column %d:" % (i + 1)
+    print result
+{% endhighlight %}
+
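+
+For comparison, a minimal Scala sketch of the same tests (the frequencies and
+contingency matrix below are illustrative placeholder data):
+
+{% highlight scala %}
+
+import org.apache.spark.mllib.linalg.{Matrices, Vectors}
+import org.apache.spark.mllib.stat.Statistics
+
+val vec = Vectors.dense(0.1, 0.15, 0.25, 0.25, 0.25)  // event frequencies
+println(Statistics.chiSqTest(vec))  // goodness of fit against the uniform distribution
+
+val mat = Matrices.dense(2, 2, Array(13.0, 47.0, 40.0, 80.0))  // contingency matrix
+println(Statistics.chiSqTest(mat))  // Pearson's independence test
+
+{% endhighlight %}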
+
## Random data generation
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 695813a2ba881..2f7e4981e5bb9 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -4,7 +4,7 @@ title: Running Spark on YARN
---
Support for running on [YARN (Hadoop
-NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html)
+NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html)
was added to Spark in version 0.6.0, and improved in subsequent releases.
# Preparations
diff --git a/docs/security.md b/docs/security.md
index ec0523184d665..1e206a139fb72 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can
* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret.
* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications.
-* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.*
## Web UI
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index d4ade939c3a6e..ffcce2c588879 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
  <td><code>spark.sql.parquet.cacheMetadata</code></td>
-  <td>false</td>
+  <td>true</td>
  <td>
    Turns on caching of Parquet schema metadata. Can speed up querying of static data.
  </td>
</tr>
<tr>
  <td><code>spark.sql.parquet.compression.codec</code></td>
-  <td>snappy</td>
+  <td>gzip</td>
  <td>
    Sets the compression codec used when writing Parquet files. Acceptable values include:
    uncompressed, snappy, gzip, lzo.
  </td>
</tr>
+<tr>
+  <td><code>spark.sql.hive.convertMetastoreParquet</code></td>
+  <td>true</td>
+  <td>
+    When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in
+    support.
+  </td>
+</tr>
## JSON Datasets
@@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
  <td><code>spark.sql.inMemoryColumnarStorage.compressed</code></td>
-  <td>false</td>
+  <td>true</td>
  <td>
    When set to true Spark SQL will automatically select a compression codec for each column based
    on statistics of the data.
  </td>
</tr>
@@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
<tr>
  <td><code>spark.sql.inMemoryColumnarStorage.batchSize</code></td>
-  <td>1000</td>
+  <td>10000</td>
  <td>
    Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization
    and compression, but risk OOMs when caching data.
  </td>
</tr>
@@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
  <td><code>spark.sql.autoBroadcastJoinThreshold</code></td>
-  <td>10000</td>
+  <td>10485760 (10 MB)</td>
  <td>
    Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
    performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
@@ -1051,7 +1059,6 @@ in Hive deployments.
**Major Hive Features**
-* Spark SQL does not currently support inserting to tables using dynamic partitioning.
* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL
doesn't support buckets yet.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 8bbba88b31978..44a1f3ad7560b 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -68,7 +68,9 @@ import org.apache.spark._
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-// Create a local StreamingContext with two working thread and batch interval of 1 second
+// Create a local StreamingContext with two working threads and a batch interval of 1 second.
+// The master requires 2 cores to prevent a starvation scenario.
+
val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))
{% endhighlight %}
@@ -586,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver](
A receiver is run within a Spark worker/executor as a long-running task, so it occupies one of the cores allocated to the Spark Streaming application. It is therefore important to remember that a Spark Streaming application needs to be allocated enough cores to process the received data as well as to run the receiver(s). A few important points to remember are:
-##### Points to remember:
+##### Points to remember
{:.no_toc}
-- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them.
-- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data.
-
+- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process it.
+- When running locally, if your master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay), so a "local" master URL in a streaming app is generally going to starve the processor.
+Thus, in any streaming app, you will generally want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally.
+See [Spark Properties](configuration.html#spark-properties).
+
### Basic Sources
{:.no_toc}
diff --git a/ec2/spark-ec2 b/ec2/spark-ec2
index 31f9771223e51..4aa908242eeaa 100755
--- a/ec2/spark-ec2
+++ b/ec2/spark-ec2
@@ -18,5 +18,9 @@
# limitations under the License.
#
-cd "`dirname $0`"
-PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@"
+# Preserve the user's CWD so that relative paths are passed correctly to
+#+ the underlying Python script.
+SPARK_EC2_DIR="$(dirname "$0")"
+
+PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \
+ python "${SPARK_EC2_DIR}/spark_ec2.py" "$@"
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 0d6b82b4944f3..a5396c2375915 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -40,9 +40,11 @@
from boto import ec2
DEFAULT_SPARK_VERSION = "1.1.0"
+SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
+MESOS_SPARK_EC2_BRANCH = "v4"
# A URL prefix from which to fetch AMI information
-AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list"
+AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH)
class UsageError(Exception):
@@ -583,10 +585,23 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
- ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4")
+ ssh(
+ host=master,
+ opts=opts,
+ command="rm -rf spark-ec2"
+ + " && "
+ + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH)
+ )
print "Deploying files to master..."
- deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules)
+ deploy_files(
+ conn=conn,
+ root_dir=SPARK_EC2_DIR + "/" + "deploy.generic",
+ opts=opts,
+ master_nodes=master_nodes,
+ slave_nodes=slave_nodes,
+ modules=modules
+ )
print "Running setup on master..."
setup_spark_cluster(master, opts)
@@ -723,6 +738,8 @@ def get_num_disks(instance_type):
# cluster (e.g. lists of masters and slaves). Files are only deployed to
# the first master instance in the cluster, and we expect the setup
# script to be run on that instance to copy them to other nodes.
+#
+# root_dir should be an absolute path to the directory with the files we want to deploy.
def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
active_master = master_nodes[0].public_dns_name
diff --git a/examples/pom.xml b/examples/pom.xml
index bc3291803c324..910eb55308b9d 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -50,6 +50,30 @@
+    <profile>
+      <id>hbase-hadoop2</id>
+      <activation>
+        <property>
+          <name>hbase.profile</name>
+          <value>hadoop2</value>
+        </property>
+      </activation>
+      <properties>
+        <hbase.version>0.98.7-hadoop2</hbase.version>
+      </properties>
+    </profile>
+    <profile>
+      <id>hbase-hadoop1</id>
+      <activation>
+        <property>
+          <name>!hbase.profile</name>
+        </property>
+      </activation>
+      <properties>
+        <hbase.version>0.98.7-hadoop1</hbase.version>
+      </properties>
+    </profile>
@@ -120,37 +144,122 @@
      <artifactId>spark-streaming-mqtt_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
-    <dependency>
-      <groupId>org.apache.hbase</groupId>
-      <artifactId>hbase</artifactId>
-      <version>${hbase.version}</version>
-      <exclusions>
-        <exclusion>
-          <groupId>asm</groupId>
-          <artifactId>asm</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.jboss.netty</groupId>
-          <artifactId>netty</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>io.netty</groupId>
-          <artifactId>netty</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>commons-logging</groupId>
-          <artifactId>commons-logging</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.jruby</groupId>
-          <artifactId>jruby-complete</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.eclipse.jetty</groupId>
-          <artifactId>jetty-server</artifactId>
-        </exclusion>
-      </exclusions>
-    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-testing-util</artifactId>
+      <version>${hbase.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.jruby</groupId>
+          <artifactId>jruby-complete</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-protocol</artifactId>
+      <version>${hbase.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-common</artifactId>
+      <version>${hbase.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-client</artifactId>
+      <version>${hbase.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>io.netty</groupId>
+          <artifactId>netty</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-server</artifactId>
+      <version>${hbase.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-client</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-mapreduce-client-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-auth</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-annotations</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-hdfs</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hbase</groupId>
+          <artifactId>hbase-hadoop1-compat</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.commons</groupId>
+          <artifactId>commons-math</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.slf4j</groupId>
+          <artifactId>slf4j-api</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-server</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-json</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>commons-io</groupId>
+          <artifactId>commons-io</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-hadoop-compat</artifactId>
+      <version>${hbase.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-hadoop-compat</artifactId>
+      <version>${hbase.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>com.twitter</groupId>
      <artifactId>algebird-core_${scala.binary.version}</artifactId>
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
index 6c177de359b60..31a79ddd3fff1 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
@@ -30,12 +30,25 @@
/**
* Logistic regression based classification.
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
public final class JavaHdfsLR {
private static final int D = 10; // Number of dimensions
private static final Random rand = new Random(42);
+ static void showWarning() {
+ String warning = "WARN: This is a naive implementation of Logistic Regression " +
+ "and is given as an example!\n" +
+ "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " +
+ "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " +
+ "for more conventional use.";
+ System.err.println(warning);
+ }
+
static class DataPoint implements Serializable {
DataPoint(double[] x, double y) {
this.x = x;
@@ -109,6 +122,8 @@ public static void main(String[] args) {
System.exit(1);
}
+ showWarning();
+
SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<String> lines = sc.textFile(args[0]);
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
index c22506491fbff..a5db8accdf138 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -45,10 +45,21 @@
* URL neighbor URL
* ...
* where URL and their neighbors are separated by space(s).
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to org.apache.spark.graphx.lib.PageRank
*/
public final class JavaPageRank {
private static final Pattern SPACES = Pattern.compile("\\s+");
+ static void showWarning() {
+ String warning = "WARN: This is a naive implementation of PageRank " +
+ "and is given as an example! \n" +
+ "Please use the PageRank implementation found in " +
+ "org.apache.spark.graphx.lib.PageRank for more conventional use.";
+ System.err.println(warning);
+ }
+
private static class Sum implements Function2<Double, Double, Double> {
@Override
public Double call(Double a, Double b) {
@@ -62,6 +73,8 @@ public static void main(String[] args) throws Exception {
System.exit(1);
}
+ showWarning();
+
SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
new file mode 100644
index 0000000000000..1af2067b2b929
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTrees {
+
+ private static void usage() {
+ System.err.println("Usage: JavaGradientBoostedTrees " +
+ " ");
+ System.exit(-1);
+ }
+
+ public static void main(String[] args) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ String algo = "Classification";
+ if (args.length >= 1) {
+ datapath = args[0];
+ }
+ if (args.length >= 2) {
+ algo = args[1];
+ }
+ if (args.length > 2) {
+ usage();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+ // Set parameters.
+ // Note: All features are treated as continuous.
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+ boostingStrategy.setNumIterations(10);
+ boostingStrategy.weakLearnerParams().setMaxDepth(5);
+
+ if (algo.equals("Classification")) {
+ // Compute the number of classes from the data.
+      Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+        @Override public Double call(LabeledPoint p) {
+          return p.label();
+        }
+      }).countByValue().size();
+ boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+
+ // Train a GradientBoosting model for classification.
+ final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+      JavaPairRDD<Double, Double> predictionAndLabel =
+        data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+          @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+            return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+          }
+        });
+      Double trainErr =
+        1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+          @Override public Boolean call(Tuple2<Double, Double> pl) {
+            return !pl._1().equals(pl._2());
+          }
+        }).count() / data.count();
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
+ } else if (algo.equals("Regression")) {
+      // Train a GradientBoosting model for regression.
+      final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+      JavaPairRDD<Double, Double> predictionAndLabel =
+        data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+          @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+            return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+          }
+        });
+      Double trainMSE =
+        predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+          @Override public Double call(Tuple2<Double, Double> pl) {
+            Double diff = pl._1() - pl._2();
+            return diff * diff;
+          }
+        }).reduce(new Function2<Double, Double, Double>() {
+          @Override public Double call(Double a, Double b) {
+            return a + b;
+          }
+        }) / data.count();
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + model);
+ } else {
+ usage();
+ }
+
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
index 981bc4f0613a9..99df259b4e8e6 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
@@ -70,7 +70,7 @@ public static void main(String[] args) {
// Create a input stream with the custom receiver on target ip:port and count the
// words in input stream of \n delimited text (eg. generated by 'nc')
    JavaReceiverInputDStream<String> lines = ssc.receiverStream(
-      new JavaCustomReceiver(args[1], Integer.parseInt(args[2])));
+      new JavaCustomReceiver(args[0], Integer.parseInt(args[1])));
    JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterable call(String x) {
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
index 45bcedebb4117..3e9f0f4b8f127 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
@@ -25,7 +25,7 @@
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.StorageLevels;
-import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
@@ -35,8 +35,9 @@
/**
* Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ *
 * Usage: JavaNetworkWordCount <hostname> <port>
- *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
*
* To run this on your local machine, you need to first run a Netcat server
* `$ nc -lk 9999`
@@ -56,7 +57,7 @@ public static void main(String[] args) {
// Create the context with a 1 second batch size
SparkConf sparkConf = new SparkConf().setAppName("JavaNetworkWordCount");
- JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000));
+ JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
// Create a JavaReceiverInputDStream on target ip:port and count the
// words in input stream of \n delimited text (eg. generated by 'nc')
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
new file mode 100644
index 0000000000000..bceda97f058ea
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.util.Arrays;
+import java.util.regex.Pattern;
+
+import scala.Tuple2;
+import com.google.common.collect.Lists;
+import com.google.common.io.Files;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.streaming.api.java.JavaStreamingContextFactory;
+
+/**
+ * Counts words in text encoded with UTF8 received from the network every second.
+ *
+ * Usage: JavaRecoverableNetworkWordCount <hostname> <port> <checkpoint-directory> <output-file>
+ *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
+ *   data. <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data
+ *   <output-file> file to which the word counts will be appended
+ *
+ *   <checkpoint-directory> and <output-file> must be absolute paths
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ *
+ * `$ nc -lk 9999`
+ *
+ * and run the example as
+ *
+ * `$ ./bin/run-example org.apache.spark.examples.streaming.JavaRecoverableNetworkWordCount \
+ * localhost 9999 ~/checkpoint/ ~/out`
+ *
+ * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create
+ * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if
+ * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
+ * the checkpoint data.
+ *
+ * Refer to the online documentation for more details.
+ */
+public final class JavaRecoverableNetworkWordCount {
+ private static final Pattern SPACE = Pattern.compile(" ");
+
+ private static JavaStreamingContext createContext(String ip,
+ int port,
+ String checkpointDirectory,
+ String outputPath) {
+
+ // If you do not see this printed, that means the StreamingContext has been loaded
+ // from the new checkpoint
+ System.out.println("Creating new context");
+ final File outputFile = new File(outputPath);
+ if (outputFile.exists()) {
+ outputFile.delete();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaRecoverableNetworkWordCount");
+ // Create the context with a 1 second batch size
+ JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
+ ssc.checkpoint(checkpointDirectory);
+
+ // Create a socket stream on target ip:port and count the
+ // words in input stream of \n delimited text (eg. generated by 'nc')
+ JavaReceiverInputDStream<String> lines = ssc.socketTextStream(ip, port);
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(SPACE.split(x));
+ }
+ });
+ JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) {
+ return i1 + i2;
+ }
+ });
+
+ wordCounts.foreachRDD(new Function2<JavaPairRDD<String, Integer>, Time, Void>() {
+ @Override
+ public Void call(JavaPairRDD<String, Integer> rdd, Time time) throws IOException {
+ String counts = "Counts at time " + time + " " + rdd.collect();
+ System.out.println(counts);
+ System.out.println("Appending to " + outputFile.getAbsolutePath());
+ Files.append(counts + "\n", outputFile, Charset.defaultCharset());
+ return null;
+ }
+ });
+
+ return ssc;
+ }
+
+ public static void main(String[] args) {
+ if (args.length != 4) {
+ System.err.println("You arguments were " + Arrays.asList(args));
+ System.err.println(
+ "Usage: JavaRecoverableNetworkWordCount \n" +
+ " . and describe the TCP server that Spark\n" +
+ " Streaming would connect to receive data. directory to\n" +
+ " HDFS-compatible file system which checkpoint data file to which\n" +
+ " the word counts will be appended\n" +
+ "\n" +
+ "In local mode, should be 'local[n]' with n > 1\n" +
+ "Both and must be absolute paths");
+ System.exit(1);
+ }
+
+ final String ip = args[0];
+ final int port = Integer.parseInt(args[1]);
+ final String checkpointDirectory = args[2];
+ final String outputPath = args[3];
+ JavaStreamingContextFactory factory = new JavaStreamingContextFactory() {
+ @Override
+ public JavaStreamingContext create() {
+ return createContext(ip, port, checkpointDirectory, outputPath);
+ }
+ };
+ JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, factory);
+ ssc.start();
+ ssc.awaitTermination();
+ }
+}
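The recovery mechanics above follow the standard checkpoint pattern: `getOrCreate` rebuilds the context from checkpoint data when it exists and otherwise calls the factory. A minimal Scala sketch of the same pattern (paths are placeholders, not part of this patch):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    def createContext(checkpointDir: String): StreamingContext = {
      val ssc = new StreamingContext(new SparkConf().setAppName("RecoverableApp"), Seconds(1))
      ssc.checkpoint(checkpointDir)  // must be set inside the factory so recovery works
      // ... define the DStream transformations here ...
      ssc
    }

    val checkpointDir = "/tmp/checkpoint"  // placeholder path
    val ssc = StreamingContext.getOrCreate(checkpointDir, () => createContext(checkpointDir))
    ssc.start()
    ssc.awaitTermination()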
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
new file mode 100644
index 0000000000000..540dae785f6ea
--- /dev/null
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+An example of how to use SchemaRDD as a dataset for ML. Run with::
+ bin/spark-submit examples/src/main/python/mllib/dataset_example.py
+"""
+
+import os
+import sys
+import tempfile
+import shutil
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.stat import Statistics
+
+
+def summarize(dataset):
+ print "schema: %s" % dataset.schema().json()
+ labels = dataset.map(lambda r: r.label)
+ print "label average: %f" % labels.mean()
+ features = dataset.map(lambda r: r.features)
+ summary = Statistics.colStats(features)
+ print "features average: %r" % summary.mean()
+
+if __name__ == "__main__":
+ if len(sys.argv) > 2:
+ print >> sys.stderr, "Usage: dataset_example.py "
+ exit(-1)
+ sc = SparkContext(appName="DatasetExample")
+ sqlCtx = SQLContext(sc)
+ if len(sys.argv) == 2:
+ input = sys.argv[1]
+ else:
+ input = "data/mllib/sample_libsvm_data.txt"
+ points = MLUtils.loadLibSVMFile(sc, input)
+ dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+ summarize(dataset0)
+ tempdir = tempfile.NamedTemporaryFile(delete=False).name
+ os.unlink(tempdir)
+ print "Save dataset as a Parquet file to %s." % tempdir
+ dataset0.saveAsParquetFile(tempdir)
+ print "Load it back and summarize it again."
+ dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+ summarize(dataset1)
+ shutil.rmtree(tempdir)
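The summarize step above relies on `Statistics.colStats`, which computes column-wise statistics over an RDD of vectors. A minimal Scala sketch of the same step (assuming `sc` is an existing SparkContext):

    import org.apache.spark.mllib.stat.Statistics
    import org.apache.spark.mllib.util.MLUtils

    val points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    val summary = Statistics.colStats(points.map(_.features))
    println(summary.mean)      // per-column means
    println(summary.variance)  // per-column variances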
diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py
new file mode 100644
index 0000000000000..99fef4276a369
--- /dev/null
+++ b/examples/src/main/python/mllib/word2vec.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This example uses text8 file from http://mattmahoney.net/dc/text8.zip
+# The file was downloaded, unzipped and split into multiple lines using
+#
+# wget http://mattmahoney.net/dc/text8.zip
+# unzip text8.zip
+# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines
+# This was done so that the example can be run in local mode
+
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.mllib.feature import Word2Vec
+
+USAGE = ("bin/spark-submit --driver-memory 4g "
+ "examples/src/main/python/mllib/word2vec.py text8_lines")
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print USAGE
+ sys.exit("Argument for file not provided")
+ file_path = sys.argv[1]
+ sc = SparkContext(appName='Word2Vec')
+ inp = sc.textFile(file_path).map(lambda row: row.split(" "))
+
+ word2vec = Word2Vec()
+ model = word2vec.fit(inp)
+
+ synonyms = model.findSynonyms('china', 40)
+
+ for word, cosine_distance in synonyms:
+ print "{}: {}".format(word, cosine_distance)
+ sc.stop()
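For comparison, the same flow in the Scala API: fit a Word2Vec model on an RDD of token sequences and query synonyms (a sketch assuming `sc` is an existing SparkContext and `text8_lines` was prepared as described above):

    import org.apache.spark.mllib.feature.Word2Vec

    val input = sc.textFile("text8_lines").map(_.split(" ").toSeq)
    val model = new Word2Vec().fit(input)
    model.findSynonyms("china", 40).foreach { case (word, cosineSimilarity) =>
      println(s"$word: $cosineSimilarity")
    }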
diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py
index b539c4128cdcc..a5f25d78c1146 100755
--- a/examples/src/main/python/pagerank.py
+++ b/examples/src/main/python/pagerank.py
@@ -15,6 +15,11 @@
# limitations under the License.
#
+"""
+This is an example implementation of PageRank. For more conventional use,
+please refer to the PageRank implementation provided by graphx
+"""
+
import re
import sys
from operator import add
@@ -40,6 +45,9 @@ def parseNeighbors(urls):
print >> sys.stderr, "Usage: pagerank "
exit(-1)
+ print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is
+ given as an example! Please refer to the PageRank implementation provided by graphx"""
+
# Initialize the spark context.
sc = SparkContext(appName="PythonPageRank")
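For the conventional GraphX route the warning points to, the whole computation is a single call on a loaded graph. A minimal Scala sketch (the edge-list path is a placeholder; assumes `sc` is an existing SparkContext):

    import org.apache.spark.graphx.GraphLoader

    val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt")
    val ranks = graph.pageRank(0.0001).vertices  // tolerance-based convergence
    ranks.take(10).foreach(println)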
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
index 931faac5463c4..ac2ea35bbd0e0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
@@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector}
* Logistic regression based classification.
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object LocalFileLR {
val D = 10 // Number of dimensions
@@ -41,7 +42,8 @@ object LocalFileLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
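For the conventional route these warnings recommend, here is a minimal Scala sketch using LogisticRegressionWithLBFGS (the input path is a placeholder; assumes `sc` is an existing SparkContext):

    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    import org.apache.spark.mllib.util.MLUtils

    val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()
    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(2)
      .run(training)
    println(s"Weights: ${model.weights}")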
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
index 2d75b9d2590f8..92a683ad57ea1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
@@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector}
* Logistic regression based classification.
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object LocalLR {
val N = 10000 // Number of data points
@@ -48,7 +49,8 @@ object LocalLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index 3258510894372..9099c2fcc90b3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -32,7 +32,8 @@ import org.apache.spark.scheduler.InputFormatInfo
* Logistic regression based classification.
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object SparkHdfsLR {
val D = 10 // Number of dimensions
@@ -54,7 +55,8 @@ object SparkHdfsLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
index fc23308fc4adf..257a7d29f922a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
@@ -30,7 +30,8 @@ import org.apache.spark._
* Usage: SparkLR [slices]
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object SparkLR {
val N = 10000 // Number of data points
@@ -53,7 +54,8 @@ object SparkLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
index 4c7e006da0618..8d092b6506d33 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
@@ -28,13 +28,28 @@ import org.apache.spark.{SparkConf, SparkContext}
* URL neighbor URL
* ...
* where URL and their neighbors are separated by space(s).
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to org.apache.spark.graphx.lib.PageRank
*/
object SparkPageRank {
+
+ def showWarning() {
+ System.err.println(
+ """WARN: This is a naive implementation of PageRank and is given as an example!
+ |Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank
+ |for more conventional use.
+ """.stripMargin)
+ }
+
def main(args: Array[String]) {
if (args.length < 1) {
System.err.println("Usage: SparkPageRank ")
System.exit(1)
}
+
+ showWarning()
+
val sparkConf = new SparkConf().setAppName("PageRank")
val iters = if (args.length > 0) args(1).toInt else 10
val ctx = new SparkContext(sparkConf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
index 96d13612e46dd..4393b99e636b6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
@@ -32,11 +32,24 @@ import org.apache.spark.storage.StorageLevel
/**
* Logistic regression based classification.
* This example uses Tachyon to persist rdds during computation.
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object SparkTachyonHdfsLR {
val D = 10 // Number of dimensions
val rand = new Random(42)
+ def showWarning() {
+ System.err.println(
+ """WARN: This is a naive implementation of Logistic Regression and is given as an example!
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+ |for more conventional use.
+ """.stripMargin)
+ }
+
case class DataPoint(x: Vector[Double], y: Double)
def parsePoint(line: String): DataPoint = {
@@ -51,6 +64,9 @@ object SparkTachyonHdfsLR {
}
def main(args: Array[String]) {
+
+ showWarning()
+
val inputPath = args(0)
val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR")
val conf = new Configuration()
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index d70d93608a57c..828cffb01ca1e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -77,7 +77,7 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")"))
val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel).cache()
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
@@ -110,7 +110,7 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")"))
val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel).cache()
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
@@ -131,7 +131,7 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")"))
val graph = GraphLoader.edgeListFile(sc, fname,
canonicalOrientation = true,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel)
// TriangleCount requires the graph to be partitioned
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
new file mode 100644
index 0000000000000..f8d83f4ec7327
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.io.File
+
+import com.google.common.io.Files
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+
+/**
+ * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DatasetExample {
+
+ case class Params(
+ input: String = "data/mllib/sample_libsvm_data.txt",
+ dataFormat: String = "libsvm") extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DatasetExample") {
+ head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ opt[String]("input")
+ .text(s"input path to dataset")
+ .action((x, c) => c.copy(input = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ success
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._ // for implicit conversions
+
+ // Load input data
+ val origData: RDD[LabeledPoint] = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
+ }
+ println(s"Loaded ${origData.count()} instances from file: ${params.input}")
+
+ // Convert input data to SchemaRDD explicitly.
+ val schemaRDD: SchemaRDD = origData
+ println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
+ println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+
+ // Select columns, using implicit conversion to SchemaRDD.
+ val labelsSchemaRDD: SchemaRDD = origData.select('label)
+ val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ val numLabels = labels.count()
+ val meanLabel = labels.fold(0.0)(_ + _) / numLabels
+ println(s"Selected label column with average value $meanLabel")
+
+ val featuresSchemaRDD: SchemaRDD = origData.select('features)
+ val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val outputDir = new File(tmpDir, "dataset").toString
+ println(s"Saving to $outputDir as Parquet file.")
+ schemaRDD.saveAsParquetFile(outputDir)
+
+ println(s"Loading Parquet file with UDT from $outputDir.")
+ val newDataset = sqlContext.parquetFile(outputDir)
+
+ println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
+ val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+
+ sc.stop()
+ }
+
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 0890e6263e165..63f02cf7b98b9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -62,7 +62,10 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
- fracTest: Double = 0.2) extends AbstractParams[Params]
+ fracTest: Double = 0.2,
+ useNodeIdCache: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("useNodeIdCache")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.useNodeIdCache}")
+ .action((x, c) => c.copy(useNodeIdCache = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }}")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
@@ -136,20 +154,30 @@ object DecisionTreeRunner {
}
}
- def run(params: Params) {
-
- val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
- val sc = new SparkContext(conf)
-
- println(s"DecisionTreeRunner with parameters:\n$params")
-
+ /**
+ * Load training and test data from files.
+ * @param input Path to input dataset.
+ * @param dataFormat "libsvm" or "dense"
+ * @param testInput Path to test dataset.
+ * @param algo Classification or Regression
+ * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given.
+ * @return (training dataset, test dataset, number of classes),
+ * where the number of classes is inferred from data (and set to 0 for Regression)
+ */
+ private[mllib] def loadDatasets(
+ sc: SparkContext,
+ input: String,
+ dataFormat: String,
+ testInput: String,
+ algo: Algo,
+ fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = {
// Load training data and cache it.
- val origExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ val origExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache()
}
// For classification, re-index classes if needed.
- val (examples, classIndexMap, numClasses) = params.algo match {
+ val (examples, classIndexMap, numClasses) = algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
@@ -187,14 +215,14 @@ object DecisionTreeRunner {
}
// Create training, test sets.
- val splits = if (params.testInput != "") {
+ val splits = if (testInput != "") {
// Load testInput.
val numFeatures = examples.take(1)(0).features.size
- val origTestExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
+ val origTestExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
}
- params.algo match {
+ algo match {
case Classification => {
// classCounts: class --> # examples in class
val testExamples = {
@@ -211,17 +239,31 @@ object DecisionTreeRunner {
}
} else {
// Split input into training, test.
- examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ examples.randomSplit(Array(1.0 - fracTest, fracTest))
}
val training = splits(0).cache()
val test = splits(1).cache()
+
val numTraining = training.count()
val numTest = test.count()
-
println(s"numTraining = $numTraining, numTest = $numTest.")
examples.unpersist(blocking = false)
+ (training, test, numClasses)
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"DecisionTreeRunner with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat,
+ params.testInput, params.algo, params.fracTest)
+
val impurityCalculator = params.impurity match {
case Gini => impurity.Gini
case Entropy => impurity.Entropy
@@ -236,7 +278,10 @@ object DecisionTreeRunner {
maxBins = params.maxBins,
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
- minInfoGain = params.minInfoGain)
+ minInfoGain = params.minInfoGain,
+ useNodeIdCache = params.useNodeIdCache,
+ checkpointDir = params.checkpointDir,
+ checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
@@ -317,7 +362,9 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+ private[mllib] def meanSquaredError(
+ tree: WeightedEnsembleModel,
+ data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
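The three new options feed straight into the tree Strategy, as the hunk above shows. A minimal sketch of constructing such a Strategy directly (all values are hypothetical):

    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    val strategy = new Strategy(
      algo = Algo.Classification,
      impurity = Gini,
      maxDepth = 5,
      numClassesForClassification = 2,
      useNodeIdCache = true,              // keep node IDs cached on executors
      checkpointDir = Some("/tmp/ckpt"),  // placeholder checkpoint location
      checkpointInterval = 10)            // checkpoint the node Id cache every 10 iterations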
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
new file mode 100644
index 0000000000000..9b6db01448be0
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTrees {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ numIterations: Int = 10,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GradientBoostedTrees") {
+ head("GradientBoostedTrees: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("numIterations")
+ .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+ .action((x, c) => c.copy(numIterations = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"GradientBoostedTrees with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+ val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+ boostingStrategy.numClassesForClassification = numClasses
+ boostingStrategy.numIterations = params.numIterations
+ boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+
+ val randomSeed = Utils.random.nextInt()
+ if (params.algo == "Classification") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $testAccuracy")
+ } else if (params.algo == "Regression") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
+ }
+
+ sc.stop()
+ }
+}
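The two training paths above wrap a small API surface; a minimal sketch of calling it directly (assumes `training` is an RDD[LabeledPoint] prepared elsewhere):

    import org.apache.spark.mllib.tree.GradientBoosting
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10
    val model = GradientBoosting.trainClassifier(training, boostingStrategy)
    println(s"Total nodes across trees: ${model.totalNumNodes}")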
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 8796c28db8a66..91a0a860d6c71 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -106,9 +106,11 @@ object MovieLensALS {
Logger.getRootLogger.setLevel(Level.WARN)
+ val implicitPrefs = params.implicitPrefs
+
val ratings = sc.textFile(params.input).map { line =>
val fields = line.split("::")
- if (params.implicitPrefs) {
+ if (implicitPrefs) {
/*
* MovieLens ratings are on a scale of 1-5:
* 5: Must see
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
new file mode 100644
index 0000000000000..33e5760aed997
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+
+/**
+ * Estimate clusters on one stream of data and make predictions
+ * on another stream, where the data streams arrive as text files
+ * into two different directories.
+ *
+ * The rows of the training text files must be vector data in the form
+ * `[x1,x2,x3,...,xn]`
+ * where n is the number of dimensions.
+ *
+ * The rows of the test text files must be labeled data in the form
+ * `(y,[x1,x2,x3,...,xn])`
+ * where y is some identifier. n must be the same for train and test.
+ *
+ * Usage: StreamingKMeans <trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>
+ *
+ * To run on your local machine using the two directories `trainingDir` and `testDir`,
+ * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call:
+ * $ bin/run-example \
+ * org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2
+ *
+ * As you add text files to `trainingDir` the clusters will continuously update.
+ * Anytime you add text files to `testDir`, you'll see predicted labels using the current model.
+ *
+ */
+object StreamingKMeans {
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println(
+ "Usage: StreamingKMeans " +
+ "")
+ System.exit(1)
+ }
+
+ val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression")
+ val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
+
+ val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
+ val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
+
+ val model = new StreamingKMeans()
+ .setK(args(3).toInt)
+ .setDecayFactor(1.0)
+ .setRandomCenters(args(4).toInt, 0.0)
+
+ model.trainOn(trainingData)
+ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
+
+ ssc.start()
+ ssc.awaitTermination()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
index 6af3a0f33efc2..19427e629f76d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
@@ -31,15 +31,13 @@ import org.apache.spark.util.IntParam
/**
* Counts words in text encoded with UTF8 received from the network every second.
*
- * Usage: NetworkWordCount <hostname> <port> <checkpoint-directory> <output-file>
+ * Usage: RecoverableNetworkWordCount <hostname> <port> <checkpoint-directory> <output-file>
*   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
*   data. <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data
*   <output-file> file to which the word counts will be appended
*
- * In local mode, <master> should be 'local[n]' with n > 1
*   <checkpoint-directory> and <output-file> must be absolute paths
*
- *
* To run this on your local machine, you need to first run a Netcat server
*
* `$ nc -lk 9999`
@@ -54,22 +52,11 @@ import org.apache.spark.util.IntParam
* checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
* the checkpoint data.
*
- * To run this example in a local standalone cluster with automatic driver recovery,
- *
- * `$ bin/spark-class org.apache.spark.deploy.Client -s launch \
- *    <cluster-url> <path-to-examples-jar> \
- * org.apache.spark.examples.streaming.RecoverableNetworkWordCount \
- * localhost 9999 ~/checkpoint ~/out`
- *
- * <path-to-examples-jar> would typically be
- * <path-to-spark>/examples/target/scala-XX/spark-examples....jar
- *
* Refer to the online documentation for more details.
*/
-
object RecoverableNetworkWordCount {
- def createContext(ip: String, port: Int, outputPath: String) = {
+ def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = {
// If you do not see this printed, that means the StreamingContext has been loaded
// from the new checkpoint
@@ -79,6 +66,7 @@ object RecoverableNetworkWordCount {
val sparkConf = new SparkConf().setAppName("RecoverableNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
+ ssc.checkpoint(checkpointDirectory)
// Create a socket stream on target ip:port and count the
// words in input stream of \n delimited text (eg. generated by 'nc')
@@ -114,7 +102,7 @@ object RecoverableNetworkWordCount {
val Array(ip, IntParam(port), checkpointDirectory, outputPath) = args
val ssc = StreamingContext.getOrCreate(checkpointDirectory,
() => {
- createContext(ip, port, outputPath)
+ createContext(ip, port, outputPath, checkpointDirectory)
})
ssc.start()
ssc.awaitTermination()
diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties
new file mode 100644
index 0000000000000..4411d6e20c52a
--- /dev/null
+++ b/external/flume-sink/src/test/resources/log4j.properties
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file streaming/target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+
diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
index a2b2cc6149d95..650b2fbe1c142 100644
--- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
+++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
@@ -159,6 +159,7 @@ class SparkSinkSuite extends FunSuite {
channelContext.put("transactionCapacity", 1000.toString)
channelContext.put("keep-alive", 0.toString)
channelContext.putAll(overrides)
+ channel.setName(scala.util.Random.nextString(10))
channel.configure(channelContext)
val sink = new SparkSink()
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 32a19787a28e1..475026e8eb140 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -145,11 +145,16 @@ class FlumePollingStreamSuite extends TestSuiteBase {
outputStream.register()
ssc.start()
- writeAndVerify(Seq(channel, channel2), ssc, outputBuffer)
- assertChannelIsEmpty(channel)
- assertChannelIsEmpty(channel2)
- sink.stop()
- channel.stop()
+ try {
+ writeAndVerify(Seq(channel, channel2), ssc, outputBuffer)
+ assertChannelIsEmpty(channel)
+ assertChannelIsEmpty(channel2)
+ } finally {
+ sink.stop()
+ sink2.stop()
+ channel.stop()
+ channel2.stop()
+ }
}
def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext,
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
index e20e2c8f26991..28ac5929df44a 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
@@ -26,8 +26,6 @@ import java.util.concurrent.Executors
import kafka.consumer._
import kafka.serializer.Decoder
import kafka.utils.VerifiableProperties
-import kafka.utils.ZKStringSerializer
-import org.I0Itec.zkclient._
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
@@ -97,12 +95,6 @@ class KafkaReceiver[
consumerConnector = Consumer.create(consumerConfig)
logInfo("Connected to " + zkConnect)
- // When auto.offset.reset is defined, it is our responsibility to try and whack the
- // consumer group zk node.
- if (kafkaParams.contains("auto.offset.reset")) {
- tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id"))
- }
-
val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties])
.newInstance(consumerConfig.props)
.asInstanceOf[Decoder[K]]
@@ -139,26 +131,4 @@ class KafkaReceiver[
}
}
}
-
- // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This
- // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper.
- //
- // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied
- // from Kafka's ConsoleConsumer. See code related to 'auto.offset.reset' when it is set to
- // 'smallest'/'largest':
- // scalastyle:off
- // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala
- // scalastyle:on
- private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) {
- val dir = "/consumers/" + groupId
- logInfo("Cleaning up temporary Zookeeper data under " + dir + ".")
- val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer)
- try {
- zk.deleteRecursive(dir)
- } catch {
- case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e)
- } finally {
- zk.close()
- }
- }
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 48668f763e41e..ec812e1ef3b04 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -17,19 +17,18 @@
package org.apache.spark.streaming.kafka
-import scala.reflect.ClassTag
-import scala.collection.JavaConversions._
-
import java.lang.{Integer => JInt}
import java.util.{Map => JMap}
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
import kafka.serializer.{Decoder, StringDecoder}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext, JavaPairDStream}
-import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream}
-
+import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext}
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
object KafkaUtils {
/**
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index 5bcb96b136ed7..5267560b3e5ce 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -82,12 +82,17 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag](
this
}
- /** Persists the vertex partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
+ /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
override def cache(): this.type = {
partitionsRDD.persist(targetStorageLevel)
this
}
+ /** The number of edges in the RDD. */
+ override def count(): Long = {
+ partitionsRDD.map(_._2.size.toLong).reduce(_ + _)
+ }
+
private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag](
f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = {
this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter =>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
index f4c79365b16da..4933aecba1286 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
@@ -48,7 +48,8 @@ object GraphLoader extends Logging {
* @param path the path to the file (e.g., /home/data/file or hdfs://file)
* @param canonicalOrientation whether to orient edges in the positive
* direction
- * @param minEdgePartitions the number of partitions for the edge RDD
+ * @param numEdgePartitions the number of partitions for the edge RDD
+ * Setting this value to -1 will use the default parallelism.
* @param edgeStorageLevel the desired storage level for the edge partitions
* @param vertexStorageLevel the desired storage level for the vertex partitions
*/
@@ -56,7 +57,7 @@ object GraphLoader extends Logging {
sc: SparkContext,
path: String,
canonicalOrientation: Boolean = false,
- minEdgePartitions: Int = 1,
+ numEdgePartitions: Int = -1,
edgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY,
vertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
: Graph[Int, Int] =
@@ -64,7 +65,12 @@ object GraphLoader extends Logging {
val startTime = System.currentTimeMillis
// Parse the edge data table directly into edge partitions
- val lines = sc.textFile(path, minEdgePartitions).coalesce(minEdgePartitions)
+ val lines =
+ if (numEdgePartitions > 0) {
+ sc.textFile(path, numEdgePartitions).coalesce(numEdgePartitions)
+ } else {
+ sc.textFile(path)
+ }
val edges = lines.mapPartitionsWithIndex { (pid, iter) =>
val builder = new EdgePartitionBuilder[Int, Int]
iter.foreach { line =>
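With this change, callers that previously pinned minEdgePartitions can either omit the argument (to get the default parallelism) or pass an explicit count. A minimal usage sketch (path and count are placeholders; assumes `sc` is an existing SparkContext):

    import org.apache.spark.graphx.GraphLoader

    // -1 (the default) defers to sc's default parallelism
    val graph = GraphLoader.edgeListFile(sc, "/path/to/edges.txt", numEdgePartitions = 64)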
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index 2c8b245955d12..12216d9d33d66 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -27,8 +27,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.graphx.impl.RoutingTablePartition
import org.apache.spark.graphx.impl.ShippableVertexPartition
import org.apache.spark.graphx.impl.VertexAttributeBlock
-import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._
-import org.apache.spark.graphx.impl.VertexRDDFunctions._
/**
* Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by
@@ -233,7 +231,7 @@ class VertexRDD[@specialized VD: ClassTag](
case _ =>
this.withPartitionsRDD[VD3](
partitionsRDD.zipPartitions(
- other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
(partIter, msgs) => partIter.map(_.leftJoin(msgs)(f))
}
)
@@ -277,7 +275,7 @@ class VertexRDD[@specialized VD: ClassTag](
case _ =>
this.withPartitionsRDD(
partitionsRDD.zipPartitions(
- other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
(partIter, msgs) => partIter.map(_.innerJoin(msgs)(f))
}
)
@@ -297,7 +295,7 @@ class VertexRDD[@specialized VD: ClassTag](
*/
def aggregateUsingIndex[VD2: ClassTag](
messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
- val shuffled = messages.copartitionWithVertices(this.partitioner.get)
+ val shuffled = messages.partitionBy(this.partitioner.get)
val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc))
}
@@ -371,7 +369,7 @@ object VertexRDD {
def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
- case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
+ case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
}
val vertexPartitions = vPartitioned.mapPartitions(
iter => Iterator(ShippableVertexPartition(iter)),
@@ -412,7 +410,7 @@ object VertexRDD {
): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
- case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
+ case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
}
val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get)
val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) {
@@ -454,7 +452,7 @@ object VertexRDD {
.setName("VertexRDD.createRoutingTables - vid2pid (aggregation)")
val numEdgePartitions = edges.partitions.size
- vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions(
+ vid2pid.partitionBy(vertexPartitioner).mapPartitions(
iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)),
preservesPartitioning = true)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
index 4520beb991515..2b6137be25547 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -45,8 +45,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
// Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
// adding them to the index
if (edgeArray.length > 0) {
- index.update(srcIds(0), 0)
- var currSrcId: VertexId = srcIds(0)
+ index.update(edgeArray(0).srcId, 0)
+ var currSrcId: VertexId = edgeArray(0).srcId
var i = 0
while (i < edgeArray.size) {
srcIds(i) = edgeArray(i).srcId
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
deleted file mode 100644
index 714f3b81c9dad..0000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.language.implicitConversions
-import scala.reflect.{classTag, ClassTag}
-
-import org.apache.spark.Partitioner
-import org.apache.spark.graphx.{PartitionID, VertexId}
-import org.apache.spark.rdd.{ShuffledRDD, RDD}
-
-
-private[graphx]
-class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
- def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
- val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner)
-
- // Set a custom serializer if the data is of int or double type.
- if (classTag[VD] == ClassTag.Int) {
- rdd.setSerializer(new IntAggMsgSerializer)
- } else if (classTag[VD] == ClassTag.Long) {
- rdd.setSerializer(new LongAggMsgSerializer)
- } else if (classTag[VD] == ClassTag.Double) {
- rdd.setSerializer(new DoubleAggMsgSerializer)
- }
- rdd
- }
-}
-
-private[graphx]
-object VertexRDDFunctions {
- implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = {
- new VertexRDDFunctions(rdd)
- }
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index b27485953f719..7a7fa91aadfe1 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -29,24 +29,6 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
-private[graphx]
-class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
- /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
- def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
- new ShuffledRDD[VertexId, Int, Int](
- self, partitioner).setSerializer(new RoutingTableMessageSerializer)
- }
-}
-
-private[graphx]
-object RoutingTableMessageRDDFunctions {
- import scala.language.implicitConversions
-
- implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = {
- new RoutingTableMessageRDDFunctions(rdd)
- }
-}
-
private[graphx]
object RoutingTablePartition {
/**
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
deleted file mode 100644
index 3909efcdfc993..0000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ /dev/null
@@ -1,369 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.language.existentials
-
-import java.io.{EOFException, InputStream, OutputStream}
-import java.nio.ByteBuffer
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.serializer._
-
-import org.apache.spark.graphx._
-import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
-
-private[graphx]
-class RoutingTableMessageSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream): SerializationStream =
- new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T): SerializationStream = {
- val msg = t.asInstanceOf[RoutingTableMessage]
- writeVarLong(msg._1, optimizePositive = false)
- writeInt(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream): DeserializationStream =
- new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readInt()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-private[graphx]
-class VertexIdMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, _)]
- writeVarLong(msg._1, optimizePositive = false)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- (readVarLong(optimizePositive = false), null).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Int]. */
-private[graphx]
-class IntAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Int)]
- writeVarLong(msg._1, optimizePositive = false)
- writeUnsignedVarInt(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readUnsignedVarInt()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Long]. */
-private[graphx]
-class LongAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Long)]
- writeVarLong(msg._1, optimizePositive = false)
- writeVarLong(msg._2, optimizePositive = true)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readVarLong(optimizePositive = true)
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Double]. */
-private[graphx]
-class DoubleAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Double)]
- writeVarLong(msg._1, optimizePositive = false)
- writeDouble(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readDouble()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Helper classes to shorten the implementation of those special serializers.
-////////////////////////////////////////////////////////////////////////////////
-
-private[graphx]
-abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
- // The implementation should override this one.
- def writeObject[T: ClassTag](t: T): SerializationStream
-
- def writeInt(v: Int) {
- s.write(v >> 24)
- s.write(v >> 16)
- s.write(v >> 8)
- s.write(v)
- }
-
- def writeUnsignedVarInt(value: Int) {
- if ((value >>> 7) == 0) {
- s.write(value.toInt)
- } else if ((value >>> 14) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7)
- } else if ((value >>> 21) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14)
- } else if ((value >>> 28) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14 | 0x80)
- s.write(value >>> 21)
- } else {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14 | 0x80)
- s.write(value >>> 21 | 0x80)
- s.write(value >>> 28)
- }
- }
-
- def writeVarLong(value: Long, optimizePositive: Boolean) {
- val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value
- if ((v >>> 7) == 0) {
- s.write(v.toInt)
- } else if ((v >>> 14) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7).toInt)
- } else if ((v >>> 21) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14).toInt)
- } else if ((v >>> 28) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21).toInt)
- } else if ((v >>> 35) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28).toInt)
- } else if ((v >>> 42) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35).toInt)
- } else if ((v >>> 49) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42).toInt)
- } else if ((v >>> 56) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42 | 0x80).toInt)
- s.write((v >>> 49).toInt)
- } else {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42 | 0x80).toInt)
- s.write((v >>> 49 | 0x80).toInt)
- s.write((v >>> 56).toInt)
- }
- }
-
- def writeLong(v: Long) {
- s.write((v >>> 56).toInt)
- s.write((v >>> 48).toInt)
- s.write((v >>> 40).toInt)
- s.write((v >>> 32).toInt)
- s.write((v >>> 24).toInt)
- s.write((v >>> 16).toInt)
- s.write((v >>> 8).toInt)
- s.write(v.toInt)
- }
-
- def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v))
-
- override def flush(): Unit = s.flush()
-
- override def close(): Unit = s.close()
-}
-
-private[graphx]
-abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
- // The implementation should override this one.
- def readObject[T: ClassTag](): T
-
- def readInt(): Int = {
- val first = s.read()
- if (first < 0) throw new EOFException
- (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
- }
-
- def readUnsignedVarInt(): Int = {
- var value: Int = 0
- var i: Int = 0
- def readOrThrow(): Int = {
- val in = s.read()
- if (in < 0) throw new EOFException
- in & 0xFF
- }
- var b: Int = readOrThrow()
- while ((b & 0x80) != 0) {
- value |= (b & 0x7F) << i
- i += 7
- if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long")
- b = readOrThrow()
- }
- value | (b << i)
- }
-
- def readVarLong(optimizePositive: Boolean): Long = {
- def readOrThrow(): Int = {
- val in = s.read()
- if (in < 0) throw new EOFException
- in & 0xFF
- }
- var b = readOrThrow()
- var ret: Long = b & 0x7F
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 7
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 14
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 21
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 28
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 35
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 42
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 49
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= b.toLong << 56
- }
- }
- }
- }
- }
- }
- }
- }
- if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret
- }
-
- def readLong(): Long = {
- val first = s.read()
- if (first < 0) throw new EOFException()
- (first.toLong << 56) |
- (s.read() & 0xFF).toLong << 48 |
- (s.read() & 0xFF).toLong << 40 |
- (s.read() & 0xFF).toLong << 32 |
- (s.read() & 0xFF).toLong << 24 |
- (s.read() & 0xFF) << 16 |
- (s.read() & 0xFF) << 8 |
- (s.read() & 0xFF)
- }
-
- def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
-
- override def close(): Unit = s.close()
-}
-
-private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance {
-
- override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
-
- override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
- throw new UnsupportedOperationException
-
- override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
- throw new UnsupportedOperationException
-
- // The implementation should override the following two.
- override def serializeStream(s: OutputStream): SerializationStream
- override def deserializeStream(s: InputStream): DeserializationStream
-}
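
The deleted serializers unrolled a zigzag-plus-varint wire format for vertex IDs and message payloads. For reference, a compact loop-form sketch of the same encoding (standalone code, not the removed classes): small-magnitude longs of either sign fit in a few bytes, with 7 payload bits per byte and the high bit as a continuation flag.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

def writeVarLong(out: ByteArrayOutputStream, value: Long): Unit = {
  var v = (value << 1) ^ (value >> 63) // zigzag: fold the sign into bit 0
  while ((v >>> 7) != 0) {
    out.write(((v & 0x7F) | 0x80).toInt) // set the continuation bit
    v >>>= 7
  }
  out.write(v.toInt)
}

def readVarLong(in: ByteArrayInputStream): Long = {
  var shift = 0
  var ret = 0L
  var b = in.read()
  while ((b & 0x80) != 0) {
    ret |= (b & 0x7F).toLong << shift
    shift += 7
    b = in.read()
  }
  ret |= b.toLong << shift
  (ret >>> 1) ^ -(ret & 1) // undo the zigzag
}

val buf = new ByteArrayOutputStream
Seq(0L, -1L, 300L, Long.MinValue).foreach(writeVarLong(buf, _))
val in = new ByteArrayInputStream(buf.toByteArray)
assert(Seq(0L, -1L, 300L, Long.MinValue).forall(_ == readVarLong(in)))
```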
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
deleted file mode 100644
index 864cb1fdf0022..0000000000000
--- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx
-
-import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
-
-import scala.util.Random
-import scala.reflect.ClassTag
-
-import org.scalatest.FunSuite
-
-import org.apache.spark._
-import org.apache.spark.graphx.impl._
-import org.apache.spark.serializer.SerializationStream
-
-
-class SerializerSuite extends FunSuite with LocalSparkContext {
-
- test("IntAggMsgSerializer") {
- val outMsg = (4: VertexId, 5)
- val bout = new ByteArrayOutputStream
- val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Int) = inStrm.readObject()
- val inMsg2: (VertexId, Int) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("LongAggMsgSerializer") {
- val outMsg = (4: VertexId, 1L << 32)
- val bout = new ByteArrayOutputStream
- val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Long) = inStrm.readObject()
- val inMsg2: (VertexId, Long) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("DoubleAggMsgSerializer") {
- val outMsg = (4: VertexId, 5.0)
- val bout = new ByteArrayOutputStream
- val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Double) = inStrm.readObject()
- val inMsg2: (VertexId, Double) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("variable long encoding") {
- def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
- val bout = new ByteArrayOutputStream
- val stream = new ShuffleSerializationStream(bout) {
- def writeObject[T: ClassTag](t: T): SerializationStream = {
- writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
- this
- }
- }
- stream.writeObject(v)
-
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val dstream = new ShuffleDeserializationStream(bin) {
- def readObject[T: ClassTag](): T = {
- readVarLong(optimizePositive).asInstanceOf[T]
- }
- }
- val read = dstream.readObject[Long]()
- assert(read === v)
- }
-
- // Test all variable encoding code path (each branch uses 7 bits, i.e. 1L << 7 difference)
- val d = Random.nextLong() % 128
- Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d,
- 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number =>
- testVarLongEncoding(number, optimizePositive = false)
- testVarLongEncoding(number, optimizePositive = true)
- testVarLongEncoding(-number, optimizePositive = false)
- testVarLongEncoding(-number, optimizePositive = true)
- }
- }
-}
diff --git a/mllib/pom.xml b/mllib/pom.xml
index de062a4901596..87a7ddaba97f2 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -45,6 +45,11 @@
spark-streaming_${scala.binary.version}${project.version}
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${project.version}
+ org.eclipse.jettyjetty-server
@@ -65,6 +70,10 @@
junitjunit
+
+ org.apache.commons
+ commons-math3
+
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 485abe272326c..70d7138e3060f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
-import java.util.{ArrayList => JArrayList}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -43,6 +43,7 @@ import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
+import org.apache.spark.mllib.stat.test.ChiSqTestResult
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -72,15 +73,11 @@ class PythonMLLibAPI extends Serializable {
private def trainRegressionModel(
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
- initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
+ initialWeights: Vector): JList[Object] = {
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
learner.disableUncachedWarning()
val model = learner.run(data.rdd, initialWeights)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.dumps(model.weights))
- ret.add(model.intercept: java.lang.Double)
- ret
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
}
/**
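
The refactor replaces hand-built `java.util.LinkedList`s with a one-liner: build the result as a Scala `List`, upcast the elements to `Object`, and expose it via `JavaConverters`. A sketch of the pattern in isolation (the `pack` helper is hypothetical):

```scala
import java.util.{List => JList}
import scala.collection.JavaConverters._

// Py4J only needs a java.util.List of Objects; boxing happens automatically.
def pack(weights: AnyRef, intercept: Double): JList[Object] =
  List(weights, intercept).map(_.asInstanceOf[Object]).asJava
```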
@@ -91,10 +88,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer
@@ -113,7 +110,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lrAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -125,7 +122,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val lassoAlg = new LassoWithSGD()
lassoAlg.optimizer
.setNumIterations(numIterations)
@@ -135,7 +132,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lassoAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -147,7 +144,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
ridgeAlg.optimizer
.setNumIterations(numIterations)
@@ -157,7 +154,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
ridgeAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -169,9 +166,9 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
SVMAlg.optimizer
@@ -190,7 +187,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
SVMAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -201,10 +198,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
@@ -223,7 +220,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
LogRegAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -231,13 +228,10 @@ class PythonMLLibAPI extends Serializable {
*/
def trainNaiveBayes(
data: JavaRDD[LabeledPoint],
- lambda: Double): java.util.List[java.lang.Object] = {
+ lambda: Double): JList[Object] = {
val model = NaiveBayes.train(data.rdd, lambda)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(Vectors.dense(model.labels))
- ret.add(Vectors.dense(model.pi))
- ret.add(model.theta)
- ret
+ List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
+ map(_.asInstanceOf[Object]).asJava
}
/**
@@ -259,6 +253,21 @@ class PythonMLLibAPI extends Serializable {
return kMeansAlg.run(data.rdd)
}
+ /**
+ * A wrapper of MatrixFactorizationModel to provide helper methods for Python.

+ */
+ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel)
+ extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) {
+
+ def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
+ predict(SerDe.asTupleRDD(userAndProducts.rdd))
+
+ def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ }
+
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
@@ -266,12 +275,25 @@ class PythonMLLibAPI extends Serializable {
* the Py4J documentation.
*/
def trainALSModel(
- ratings: JavaRDD[Rating],
+ ratingsJRDD: JavaRDD[Rating],
rank: Int,
iterations: Int,
lambda: Double,
- blocks: Int): MatrixFactorizationModel = {
- ALS.train(ratings.rdd, rank, iterations, lambda, blocks)
+ blocks: Int,
+ nonnegative: Boolean,
+ seed: java.lang.Long): MatrixFactorizationModel = {
+
+ val als = new ALS()
+ .setRank(rank)
+ .setIterations(iterations)
+ .setLambda(lambda)
+ .setBlocks(blocks)
+ .setNonnegative(nonnegative)
+
+ if (seed != null) als.setSeed(seed)
+
+ val model = als.run(ratingsJRDD.rdd)
+ new MatrixFactorizationModelWrapper(model)
}
/**
@@ -286,8 +308,23 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int,
- alpha: Double): MatrixFactorizationModel = {
- ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
+ alpha: Double,
+ nonnegative: Boolean,
+ seed: java.lang.Long): MatrixFactorizationModel = {
+
+ val als = new ALS()
+ .setImplicitPrefs(true)
+ .setRank(rank)
+ .setIterations(iterations)
+ .setLambda(lambda)
+ .setBlocks(blocks)
+ .setAlpha(alpha)
+ .setNonnegative(nonnegative)
+
+ if (seed != null) als.setSeed(seed)
+
+ val model = als.run(ratingsJRDD.rdd)
+ new MatrixFactorizationModelWrapper(model)
}
/**
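
Both ALS stubs now go through the builder-style setters so that `nonnegative` and an optional `seed` can be threaded in from Python. A sketch of the equivalent direct Scala call, with made-up ratings:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Hypothetical usage: explicit-feedback ALS with nonnegative factors and a
// fixed seed for reproducible runs.
def train(sc: SparkContext) = {
  val ratings = sc.parallelize(Seq(Rating(1, 10, 4.0), Rating(2, 10, 3.0)))
  val als = new ALS()
    .setRank(8)
    .setIterations(5)
    .setLambda(0.01)
    .setBlocks(-1) // -1 lets ALS choose the block count
    .setNonnegative(true)
  als.setSeed(42L)
  als.run(ratings)
}
```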
@@ -373,19 +410,16 @@ class PythonMLLibAPI extends Serializable {
rdd.rdd.map(model.transform)
}
- def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ def findSynonyms(word: String, num: Int): JList[Object] = {
val vec = transform(word)
findSynonyms(vec, num)
}
- def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ def findSynonyms(vector: Vector, num: Int): JList[Object] = {
val result = model.findSynonyms(vector, num)
val similarity = Vectors.dense(result.map(_._2))
val words = result.map(_._1)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(words)
- ret.add(similarity)
- ret
+ List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
}
@@ -395,13 +429,13 @@ class PythonMLLibAPI extends Serializable {
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
* @param data Training data
- * @param categoricalFeaturesInfoJMap Categorical features info, as Java map
+ * @param categoricalFeaturesInfo Categorical features info, as Java map
*/
def trainDecisionTreeModel(
data: JavaRDD[LabeledPoint],
algoStr: String,
numClasses: Int,
- categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
+ categoricalFeaturesInfo: JMap[Int, Int],
impurityStr: String,
maxDepth: Int,
maxBins: Int,
@@ -417,7 +451,7 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
@@ -448,6 +482,31 @@ class PythonMLLibAPI extends Serializable {
Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method))
}
+ /**
+ * Java stub for mllib Statistics.chiSqTest()
+ */
+ def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = {
+ if (expected == null) {
+ Statistics.chiSqTest(observed)
+ } else {
+ Statistics.chiSqTest(observed, expected)
+ }
+ }
+
+ /**
+ * Java stub for mllib Statistics.chiSqTest(observed: Matrix)
+ */
+ def chiSqTest(observed: Matrix): ChiSqTestResult = {
+ Statistics.chiSqTest(observed)
+ }
+
+ /**
+ * Java stub for mllib Statistics.chiSqTest(RDD[LabeledPoint])
+ */
+ def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = {
+ Statistics.chiSqTest(data.rdd)
+ }
+
// used by the corr methods to retrieve the name of the correlation method passed in via pyspark
private def getCorrNameOrDefault(method: String) = {
if (method == null) CorrelationNames.defaultCorrName else method
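
These stubs forward to the Scala `Statistics.chiSqTest` overloads. A short hedged sketch of the underlying API:

```scala
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.Statistics

// Goodness-of-fit test against a uniform expected distribution.
val gof = Statistics.chiSqTest(Vectors.dense(4.0, 6.0, 5.0))
// Independence test on a 2x2 contingency matrix (column-major values).
val indep = Statistics.chiSqTest(Matrices.dense(2, 2, Array(10.0, 20.0, 30.0, 40.0)))
println(s"${gof.pValue} ${indep.pValue}")
```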
@@ -589,7 +648,7 @@ private[spark] object SerDe extends Serializable {
if (objects.length == 0 || objects.length > 3) {
out.write(Opcodes.MARK)
}
- objects.foreach(pickler.save(_))
+ objects.foreach(pickler.save)
val code = objects.length match {
case 1 => Opcodes.TUPLE1
case 2 => Opcodes.TUPLE2
@@ -719,7 +778,7 @@ private[spark] object SerDe extends Serializable {
}
/* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
- def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = {
+ def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
rdd.map(x => Array(x._1, x._2))
}
@@ -730,7 +789,7 @@ private[spark] object SerDe extends Serializable {
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
initialize() // let it called in executor
- new PythonRDD.AutoBatchedPickler(iter)
+ new SerDeUtil.AutoBatchedPickler(iter)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
new file mode 100644
index 0000000000000..6189dce9b27da
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeansModel extends MLlib's KMeansModel for streaming
+ * algorithms, so it can keep track of a continuously updated weight
+ * associated with each cluster, and also update the model by
+ * doing a single iteration of the standard k-means algorithm.
+ *
+ * The update algorithm uses the "mini-batch" KMeans rule,
+ * generalized to incorporate forgetfulness (i.e. decay).
+ * The update rule (for each cluster) is:
+ *
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
+ * n_t+1 = n_t * a + m_t
+ *
+ * Where c_t is the previously estimated centroid for that cluster,
+ * n_t is the number of points assigned to it thus far, x_t is the centroid
+ * estimated on the current batch, and m_t is the number of points assigned
+ * to that centroid in the current batch.
+ *
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * by applying a as a discount weighting on the existing centers when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
+ * are determined entirely by recent data. Lower values correspond to
+ * more forgetting.
+ *
+ * Decay can optionally be specified by a half life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. Considering data arrived at time t, the half life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
+ *
+ */
+@DeveloperApi
+class StreamingKMeansModel(
+ override val clusterCenters: Array[Vector],
+ val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
+
+ /** Perform a k-means update on a batch of data. */
+ def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
+
+ // find nearest cluster to each point
+ val closest = data.map(point => (this.predict(point), (point, 1L)))
+
+ // get sums and counts for updating each cluster
+ val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+ BLAS.axpy(1.0, p2._1, p1._1)
+ (p1._1, p1._2 + p2._2)
+ }
+ val dim = clusterCenters(0).size
+ val pointStats: Array[(Int, (Vector, Long))] = closest
+ .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+ .collect()
+
+ val discount = timeUnit match {
+ case StreamingKMeans.BATCHES => decayFactor
+ case StreamingKMeans.POINTS =>
+ val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+ n
+ }.sum
+ math.pow(decayFactor, numNewPoints)
+ }
+
+ // apply discount to weights
+ BLAS.scal(discount, Vectors.dense(clusterWeights))
+
+ // implement update rule
+ pointStats.foreach { case (label, (sum, count)) =>
+ val centroid = clusterCenters(label)
+
+ val updatedWeight = clusterWeights(label) + count
+ val lambda = count / math.max(updatedWeight, 1e-16)
+
+ clusterWeights(label) = updatedWeight
+ BLAS.scal(1.0 - lambda, centroid)
+ BLAS.axpy(lambda / count, sum, centroid)
+
+ // display the updated cluster centers
+ val display = clusterCenters(label).size match {
+ case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...]")
+ case _ => centroid.toArray.mkString("[", ",", "]")
+ }
+
+ logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+ }
+
+ // Check whether the smallest cluster is dying. If so, split the largest cluster.
+ val weightsWithIndex = clusterWeights.view.zipWithIndex
+ val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+ val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+ if (minWeight < 1e-8 * maxWeight) {
+ logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+ val weight = (maxWeight + minWeight) / 2.0
+ clusterWeights(largest) = weight
+ clusterWeights(smallest) = weight
+ val largestClusterCenter = clusterCenters(largest)
+ val smallestClusterCenter = clusterCenters(smallest)
+ var j = 0
+ while (j < dim) {
+ val x = largestClusterCenter(j)
+ val p = 1e-14 * math.max(math.abs(x), 1.0)
+ largestClusterCenter.toBreeze(j) = x + p
+ smallestClusterCenter.toBreeze(j) = x - p
+ j += 1
+ }
+ }
+
+ this
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeans provides methods for configuring a
+ * streaming k-means analysis, training the model on streaming data,
+ * and using the model to make predictions on streaming data.
+ * See KMeansModel for details on algorithm and update rules.
+ *
+ * Use a builder pattern to construct a streaming k-means analysis
+ * in an application, like:
+ *
+ * val model = new StreamingKMeans()
+ * .setDecayFactor(0.5)
+ * .setK(3)
+ * .setRandomCenters(5, 100.0)
+ * .trainOn(DStream)
+ */
+@DeveloperApi
+class StreamingKMeans(
+ var k: Int,
+ var decayFactor: Double,
+ var timeUnit: String) extends Logging {
+
+ def this() = this(2, 1.0, StreamingKMeans.BATCHES)
+
+ protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+
+ /** Set the number of clusters. */
+ def setK(k: Int): this.type = {
+ this.k = k
+ this
+ }
+
+ /** Set the decay factor directly (for forgetful algorithms). */
+ def setDecayFactor(a: Double): this.type = {
+ this.decayFactor = a
+ this
+ }
+
+ /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
+ def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+ if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
+ throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
+ }
+ this.decayFactor = math.exp(math.log(0.5) / halfLife)
+ logInfo("Setting decay factor to: %g ".format (this.decayFactor))
+ this.timeUnit = timeUnit
+ this
+ }
+
+ /** Specify initial centers directly. */
+ def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /**
+ * Initialize random centers, requiring only the number of dimensions.
+ *
+ * @param dim Number of dimensions
+ * @param weight Weight for each center
+ * @param seed Random seed
+ */
+ def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ val random = new XORShiftRandom(seed)
+ val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+ val weights = Array.fill(k)(weight)
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /** Return the latest model. */
+ def latestModel(): StreamingKMeansModel = {
+ model
+ }
+
+ /**
+ * Update the clustering model by training on batches of data from a DStream.
+ * This operation registers a DStream for training the model,
+ * checks whether the cluster centers have been initialized,
+ * and updates the model using each batch of data from the stream.
+ *
+ * @param data DStream containing vector data
+ */
+ def trainOn(data: DStream[Vector]) {
+ assertInitialized()
+ data.foreachRDD { (rdd, time) =>
+ model = model.update(rdd, decayFactor, timeUnit)
+ }
+ }
+
+ /**
+ * Use the clustering model to make predictions on batches of data from a DStream.
+ *
+ * @param data DStream containing vector data
+ * @return DStream containing predictions
+ */
+ def predictOn(data: DStream[Vector]): DStream[Int] = {
+ assertInitialized()
+ data.map(model.predict)
+ }
+
+ /**
+ * Use the model to make predictions on the values of a DStream and carry over its keys.
+ *
+ * @param data DStream containing (key, feature vector) pairs
+ * @tparam K key type
+ * @return DStream containing the input keys and the predictions as values
+ */
+ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
+ assertInitialized()
+ data.mapValues(model.predict)
+ }
+
+ /** Check whether cluster centers have been initialized. */
+ private[this] def assertInitialized(): Unit = {
+ if (model.clusterCenters == null) {
+ throw new IllegalStateException(
+ "Initial cluster centers must be set before starting predictions")
+ }
+ }
+}
+
+private[clustering] object StreamingKMeans {
+ final val BATCHES = "batches"
+ final val POINTS = "points"
+}
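
A hedged end-to-end sketch of the new API, wiring training and prediction to a text socket stream (hypothetical host/port and comma-separated feature lines):

```scala
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.StreamingContext

def run(ssc: StreamingContext): Unit = {
  val training = ssc.socketTextStream("localhost", 9999)
    .map(line => Vectors.dense(line.split(',').map(_.toDouble)))
  val model = new StreamingKMeans()
    .setK(3)
    .setHalfLife(5.0, "batches") // decay half-life of five batches
    .setRandomCenters(dim = 2, weight = 0.0)
  model.trainOn(training)
  model.predictOn(training).print() // cluster index per incoming vector
  ssc.start()
  ssc.awaitTermination()
}
```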
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
index 7858ec602483f..078fbfbe4f0e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
@@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve {
*/
def of(curve: RDD[(Double, Double)]): Double = {
curve.sliding(2).aggregate(0.0)(
- seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+ seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
combOp = _ + _
)
}
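
Each `sliding(2)` window contributes one trapezoid. A quick worked check of the rule on a straight diagonal (local arrays, no Spark needed; `trapezoid` here mirrors the private helper this file assumes):

```scala
// Area of one trapezoid between consecutive curve points.
def trapezoid(points: Array[(Double, Double)]): Double = {
  val Array((x1, y1), (x2, y2)) = points
  (x2 - x1) * (y1 + y2) / 2.0
}

val curve = Array((0.0, 0.0), (0.5, 0.5), (1.0, 1.0))
val auc = curve.sliding(2).map(trapezoid).sum
assert(auc == 0.5) // diagonal ROC curve = random classifier
```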
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
new file mode 100644
index 0000000000000..ea10bde5fa252
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
+/**
+ * Evaluator for multilabel classification.
+ * @param predictionAndLabels an RDD of (predictions, labels) pairs, where both
+ * predictions and labels are non-null arrays with unique elements.
+ */
+class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
+
+ private lazy val numDocs: Long = predictionAndLabels.count()
+
+ private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
+   labels
+ }.distinct().count()
+
+ /**
+ * Returns subset accuracy
+ * (for equal sets of labels)
+ */
+ lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
+ predictions.deep == labels.deep
+ }.count().toDouble / numDocs
+
+ /**
+ * Returns accuracy
+ */
+ lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
+   labels.intersect(predictions).size.toDouble /
+     (labels.size + predictions.size - labels.intersect(predictions).size)
+ }.sum / numDocs
+
+
+ /**
+ * Returns the Hamming loss
+ */
+ lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.size + predictions.size - 2 * labels.intersect(predictions).size
+ }.sum / (numDocs * numLabels)
+
+ /**
+ * Returns document-based precision averaged by the number of documents
+ */
+ lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
+ if (predictions.size > 0) {
+ predictions.intersect(labels).size.toDouble / predictions.size
+ } else {
+ 0
+ }
+ }.sum / numDocs
+
+ /**
+ * Returns document-based recall averaged by the number of documents
+ */
+ lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.intersect(predictions).size.toDouble / labels.size
+ }.sum / numDocs
+
+ /**
+ * Returns document-based f1-measure averaged by the number of documents
+ */
+ lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
+ 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)
+ }.sum / numDocs
+
+ private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+ predictions.intersect(labels)
+ }.countByValue()
+
+ private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+ predictions.diff(labels)
+ }.countByValue()
+
+ private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
+ labels.diff(predictions)
+ }.countByValue()
+
+ /**
+ * Returns precision for a given label (category)
+ * @param label the label.
+ */
+ def precision(label: Double) = {
+ val tp = tpPerClass(label)
+ val fp = fpPerClass.getOrElse(label, 0L)
+ if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+ }
+
+ /**
+ * Returns recall for a given label (category)
+ * @param label the label.
+ */
+ def recall(label: Double) = {
+ val tp = tpPerClass(label)
+ val fn = fnPerClass.getOrElse(label, 0L)
+ if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+ }
+
+ /**
+ * Returns f1-measure for a given label (category)
+ * @param label the label.
+ */
+ def f1Measure(label: Double) = {
+ val p = precision(label)
+ val r = recall(label)
+ if ((p + r) == 0) 0 else 2 * p * r / (p + r)
+ }
+
+ private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
+ private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
+ private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
+
+ /**
+ * Returns micro-averaged label-based precision
+ * (equals to micro-averaged document-based precision)
+ */
+ lazy val microPrecision = {
+ val sumFp = fpPerClass.foldLeft(0L) { case (cum, (_, fp)) => cum + fp }
+ sumTp.toDouble / (sumTp + sumFp)
+ }
+
+ /**
+ * Returns micro-averaged label-based recall
+ * (equals to micro-averaged document-based recall)
+ */
+ lazy val microRecall = {
+ val sumFn = fnPerClass.foldLeft(0.0) { case (cum, (_, fn)) => cum + fn }
+ sumTp.toDouble / (sumTp + sumFn)
+ }
+
+ /**
+ * Returns micro-averaged label-based f1-measure
+ * (equals to micro-averaged document-based f1-measure)
+ */
+ lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
+
+ /**
+ * Returns the sequence of labels in ascending order
+ */
+ lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted
+}
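
A toy run of the new evaluator, assuming a two-document corpus where the first prediction set is exact and the second has one false positive and one false negative:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.evaluation.MultilabelMetrics

def demo(sc: SparkContext): Unit = {
  val predictionAndLabels = sc.parallelize(Seq(
    (Array(0.0, 1.0), Array(0.0, 1.0)),  // exact match
    (Array(0.0, 2.0), Array(0.0, 1.0)))) // predicts 2, misses 1
  val metrics = new MultilabelMetrics(predictionAndLabels)
  println(metrics.subsetAccuracy) // 0.5: only the first document matches
  println(metrics.microPrecision) // 3 TP / (3 TP + 1 FP) = 0.75
}
```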
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
new file mode 100644
index 0000000000000..693117d820580
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
+
+/**
+ * :: Experimental ::
+ * Evaluator for regression.
+ *
+ * @param predictionAndObservations an RDD of (prediction, observation) pairs.
+ */
+@Experimental
+class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
+
+ /**
+ * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
+ */
+ private lazy val summary: MultivariateStatisticalSummary = {
+ val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
+ case (prediction, observation) => Vectors.dense(observation, observation - prediction)
+ }.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, v) => summary.add(v),
+ (sum1, sum2) => sum1.merge(sum2)
+ )
+ summary
+ }
+
+ /**
+ * Returns the explained variance regression score.
+ * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+ * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+ */
+ def explainedVariance: Double = {
+ 1 - summary.variance(1) / summary.variance(0)
+ }
+
+ /**
+ * Returns the mean absolute error, which is a risk function corresponding to the
+ * expected value of the absolute error loss or l1-norm loss.
+ */
+ def meanAbsoluteError: Double = {
+ summary.normL1(1) / summary.count
+ }
+
+ /**
+ * Returns the mean squared error, which is a risk function corresponding to the
+ * expected value of the squared error loss or quadratic loss.
+ */
+ def meanSquaredError: Double = {
+ val rmse = summary.normL2(1) / math.sqrt(summary.count)
+ rmse * rmse
+ }
+
+ /**
+ * Returns the root mean squared error, which is defined as the square root of
+ * the mean squared error.
+ */
+ def rootMeanSquaredError: Double = {
+ summary.normL2(1) / math.sqrt(summary.count)
+ }
+
+ /**
+ * Returns R^2^, the coefficient of determination.
+ * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+ */
+ def r2: Double = {
+ 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
+ }
+}
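
Usage is a single constructor call over (prediction, observation) pairs; a hedged sketch with made-up numbers:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.evaluation.RegressionMetrics

def demo(sc: SparkContext): Unit = {
  val predictionAndObservations = sc.parallelize(Seq(
    (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))
  val metrics = new RegressionMetrics(predictionAndObservations)
  println(s"MAE  = ${metrics.meanAbsoluteError}")
  println(s"RMSE = ${metrics.rootMeanSquaredError}")
  println(s"R^2  = ${metrics.r2}")
}
```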
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6af225b7f49f7..ac217edc619ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@
package org.apache.spark.mllib.linalg
-import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import java.util
+import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
*
* Note: Users should not implement this interface.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
sealed trait Vector extends Serializable {
/**
@@ -74,6 +78,65 @@ sealed trait Vector extends Serializable {
}
}
+/**
+ * User-defined type for [[Vector]] which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.SchemaRDD]].
+ */
+private[spark] class VectorUDT extends UserDefinedType[Vector] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+ // vectors. The "values" field is nullable because we might want to add binary vectors later,
+ // which use "size" and "indices", but not "values".
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("size", IntegerType, nullable = true),
+ StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(4)
+ obj match {
+ case sv: SparseVector =>
+ row.setByte(0, 0)
+ row.setInt(1, sv.size)
+ row.update(2, sv.indices.toSeq)
+ row.update(3, sv.values.toSeq)
+ case dv: DenseVector =>
+ row.setByte(0, 1)
+ row.setNullAt(1)
+ row.setNullAt(2)
+ row.update(3, dv.values.toSeq)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Vector = {
+ datum match {
+ case row: Row =>
+ require(row.length == 4,
+ s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
+ val tpe = row.getByte(0)
+ tpe match {
+ case 0 =>
+ val size = row.getInt(1)
+ val indices = row.getAs[Iterable[Int]](2).toArray
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new SparseVector(size, indices, values)
+ case 1 =>
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new DenseVector(values)
+ }
+ }
+ }
+
+ override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
+
+ override def userClass: Class[Vector] = classOf[Vector]
+}
+
/**
* Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
* We don't use the name `Vector` because Scala imports
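
A sketch of the round trip the UDT implements; since `VectorUDT` is `private[spark]`, this only compiles inside the `org.apache.spark` package tree (e.g. in a test):

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}

val udt = new VectorUDT()
val sv: Vector = Vectors.sparse(4, Array(1, 3), Array(0.5, 2.0))
val row = udt.serialize(sv)     // Row: (type = 0, size = 4, indices, values)
val back = udt.deserialize(row) // recovers an equivalent SparseVector
assert(back.toArray.sameElements(sv.toArray))
```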
@@ -191,6 +254,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {
override def size: Int = values.length
@@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
* @param indices index array, assume to be strictly increasing.
* @param values value array, must have the same length as the index array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class SparseVector(
override val size: Int,
val indices: Array[Int],
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index b5e403bc8c14d..57c0768084e41 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd
import scala.language.implicitConversions
import scala.reflect.ClassTag
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
@@ -28,8 +29,8 @@ import org.apache.spark.util.Utils
/**
* Machine learning specific RDD functions.
*/
-private[mllib]
-class RDDFunctions[T: ClassTag](self: RDD[T]) {
+@DeveloperApi
+class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
/**
* Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
@@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
* trigger a Spark job if the parent RDD has more than one partitions and the window size is
* greater than 1.
*/
- def sliding(windowSize: Int): RDD[Seq[T]] = {
+ def sliding(windowSize: Int): RDD[Array[T]] = {
require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.")
if (windowSize == 1) {
- self.map(Seq(_))
+ self.map(Array(_))
} else {
new SlidingRDD[T](self, windowSize)
}
@@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
}
}
-private[mllib]
+@DeveloperApi
object RDDFunctions {
/** Implicit conversion from an RDD to RDDFunctions. */
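
With the API now public and returning `Array[T]` windows, typical usage looks like this (hypothetical data):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._

def demo(sc: SparkContext): Unit = {
  // Windows span partition boundaries: [1,2,3], [2,3,4], [3,4,5].
  val windows = sc.parallelize(1 to 5, 2).sliding(3).collect()
  windows.foreach(w => println(w.mkString(",")))
}
```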
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
index dd80782c0f001..35e81fcb3de0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
@@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]
*/
private[mllib]
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
- extends RDD[Seq[T]](parent) {
+ extends RDD[Array[T]](parent) {
require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")
- override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
val part = split.asInstanceOf[SlidingRDDPartition[T]]
(firstParent[T].iterator(part.prev, context) ++ part.tail)
.sliding(windowSize)
.withPartial(false)
+ .map(_.toArray)
}
override def getPreferredLocations(split: Partition): Seq[String] =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 6737a2f4176c2..78acc17f901c1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.train(input)
- rfModel.trees(0)
+ rfModel.weakHypotheses(0)
}
}
@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
+ * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+ * each value in the array is the data point's node Id
+ * for a corresponding tree. This is used to prevent the need
+ * to pass the entire tree to the executors during
+ * the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
@@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
- timer: TimeTracker = new TimeTracker): Unit = {
+ timer: TimeTracker = new TimeTracker,
+ nodeIdCache: Option[NodeIdCache] = None): Unit = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
+ logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+ /**
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: RandomForest.NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Unit = {
+ if (nodeInfo != null) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
+ instanceWeight, featuresForNode)
+ }
+ }
+ }
/**
* Performs a sequential aggregation over a partition.
@@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
- val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
- // If the example does not reach a node in this group, then nodeIndex = null.
- if (nodeInfo != null) {
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
- if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
- } else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
- instanceWeight, featuresForNode)
- }
- }
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+ }
+
+ agg
+ }
+
+ /**
+   * Does the same thing as binSeqOp, but reads each tree's node index from the node Id cache
+   * instead of computing it by traversing the tree.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+        val nodeIds = dataPoint._2
+        val nodeIndex = nodeIds(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
+
agg
}
@@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging {
// Finally, only best Splits for nodes are collected to driver to construct decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
- val nodeToBestSplits =
+
+    val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+ input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+        // iterate over all instances in the current partition and update the aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partitions using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }
+ } else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
@@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
- }.reduceByKey((a, b) => a.merge(b))
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
@@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging {
timer.stop("chooseSplits")
+ val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+ Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+ metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+ } else {
+ null
+ }
+
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging {
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+ if (nodeIdCache.nonEmpty) {
+ val nodeIndexUpdater = NodeIndexUpdater(
+ split = split,
+ nodeIndex = nodeIndex)
+ nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+ }
+
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
}
}
+ if (nodeIdCache.nonEmpty) {
+ // Update the cache if needed.
+ nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
+ }
}
/**
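The cached path above follows the same per-partition aggregation pattern as the original code: build one aggregator per node inside each partition, emit `(nodeIndex, aggregator)` pairs, and merge partial aggregates across partitions with `reduceByKey`. A minimal, self-contained sketch of that pattern (hedged: simple Long counts stand in for `DTStatsAggregator`, node ids 0..2 stand in for tree nodes, and `sc` is an assumed SparkContext):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._  // for reduceByKey on pair RDDs

def aggregationPattern(sc: SparkContext): Unit = {
  val numNodes = 3
  val nodeIdOfInstance = sc.parallelize(Seq(0, 1, 1, 2, 0, 2), 2)
  val partial = nodeIdOfInstance.mapPartitions { instances =>
    val aggs = new Array[Long](numNodes)        // one aggregate per node
    instances.foreach(id => aggs(id) += 1L)     // sequential update, like binSeqOp
    aggs.view.zipWithIndex.map(_.swap).iterator // (nodeIndex, partialAgg) pairs
  }
  partial.reduceByKey(_ + _).collect().foreach(println)
}
```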
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
new file mode 100644
index 0000000000000..f729344a682e2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: Experimental ::
+ * A class that implements Stochastic Gradient Boosting
+ * for regression and binary classification problems.
+ *
+ * The implementation is based upon:
+ * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes:
+ * - This currently can be run with several loss functions. However, only SquaredError is
+ * fully supported. Specifically, the loss function should be used to compute the gradient
+ * (to re-label training instances on each iteration) and to weight weak hypotheses.
+ * Currently, gradients are computed correctly for the available loss functions,
+ * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ *    Running with those losses will likely behave reasonably, but the results lack the same
+ *    guarantees.
+ *
+ * @param boostingStrategy Parameters for the gradient boosting algorithm
+ */
+@Experimental
+class GradientBoosting (
+ private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+
+ boostingStrategy.weakLearnerParams.algo = Regression
+ boostingStrategy.weakLearnerParams.impurity = impurity.Variance
+
+ // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+ boostingStrategy.weakLearnerParams.numClassesForClassification =
+ boostingStrategy.numClassesForClassification
+
+ boostingStrategy.assertValid()
+
+ /**
+ * Method to train a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+ algo match {
+ case Regression => GradientBoosting.boost(input, boostingStrategy)
+ case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoosting.boost(remappedInput, boostingStrategy)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+ }
+ }
+
+}
+
+
+object GradientBoosting extends Logging {
+
+ /**
+ * Method to train a gradient boosting model.
+ *
+ * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+ * is recommended to clearly specify regression.
+ * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+   *       is recommended to clearly specify classification.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting classification model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+    require(algo == Classification, s"Only Classification is supported. Provided algo: $algo.")
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting regression model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+    require(algo == Regression, s"Only Regression is supported. Provided algo: $algo.")
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
+ */
+ def train(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ train(input.rdd, boostingStrategy)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+ */
+ def trainClassifier(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainClassifier(input.rdd, boostingStrategy)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+ */
+ def trainRegressor(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainRegressor(input.rdd, boostingStrategy)
+ }
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param boostingStrategy boosting parameters
+   * @return WeightedEnsembleModel containing the boosted trees and their weights.
+ */
+ private def boost(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ // Initialize gradient boosting parameters
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+ val strategy = boostingStrategy.weakLearnerParams
+
+ // Cache input
+ if (input.getStorageLevel == StorageLevel.NONE) {
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+ var data = input
+
+ // Initialize tree
+ timer.start("building tree 0")
+ val firstTreeModel = new DecisionTree(strategy).train(data)
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = 1.0
+ val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
+ Sum)
+ logDebug("error of gbt = " + loss.computeError(startingModel, input))
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+    // pseudo-residuals for the second iteration (the negative gradient)
+    data = input.map(point => LabeledPoint(-loss.gradient(startingModel, point),
+ point.features))
+
+ var m = 1
+ while (m < numIterations) {
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+ val model = new DecisionTree(strategy).train(data)
+ timer.stop(s"building tree $m")
+ // Create partial model
+ baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
+ baseLearnerWeights(m) = learningRate
+ // Note: A model of type regression is used since we require raw prediction
+ val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
+ baseLearnerWeights.slice(0, m + 1), Regression, Sum)
+ logDebug("error of gbt = " + loss.computeError(partialModel, input))
+ // Update data with pseudo-residuals
+ data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+ point.features))
+ m += 1
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
+
+ }
+
+}
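A usage sketch for the new API (hedged: `trainingData` is an assumed RDD of labeled points with binary {0, 1} labels):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoosting
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.rdd.RDD

// Sketch: train a small boosted ensemble and inspect its size.
def gbtExample(trainingData: RDD[LabeledPoint]): Unit = {
  val boostingStrategy = BoostingStrategy.defaultParams("Classification")
  boostingStrategy.numIterations = 10 // keep the run small
  val model = GradientBoosting.trainClassifier(trainingData, boostingStrategy)
  println(s"Trained an ensemble of ${model.weakHypotheses.length} trees")
}
```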
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index ebbd8e0257209..9683916d9b3f1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -26,8 +26,9 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -59,7 +60,7 @@ import org.apache.spark.util.Utils
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
- * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
*/
@Experimental
private class RandomForest (
@@ -78,9 +79,9 @@ private class RandomForest (
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): RandomForestModel = {
+ def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
val timer = new TimeTracker()
@@ -111,11 +112,20 @@ private class RandomForest (
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- val baggedInput = if (numTrees > 1) {
- BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
- } else {
- BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
- }.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val (subsample, withReplacement) = {
+ // TODO: Have a stricter check for RF in the strategy
+ val isRandomForest = numTrees > 1
+ if (isRandomForest) {
+ (1.0, true)
+ } else {
+ (strategy.subsamplingRate, false)
+ }
+ }
+
+    val baggedInput =
+      BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
+ .persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -150,6 +160,19 @@ private class RandomForest (
* in lower levels).
*/
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ val nodeIdCache = if (strategy.useNodeIdCache) {
+ Some(NodeIdCache.init(
+ data = baggedInput,
+ numTrees = numTrees,
+ checkpointDir = strategy.checkpointDir,
+ checkpointInterval = strategy.checkpointInterval,
+ initVal = 1))
+ } else {
+ None
+ }
+
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
@@ -172,7 +195,7 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
- treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
}
@@ -183,8 +206,14 @@ private class RandomForest (
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
+ // Delete any remaining checkpoints used for node Id cache.
+ if (nodeIdCache.nonEmpty) {
+ nodeIdCache.get.deleteAllCheckpoints()
+ }
+
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
- RandomForestModel.build(trees)
+ val treeWeights = Array.fill[Double](numTrees)(1.0)
+ new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
}
}
@@ -205,14 +234,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -243,7 +272,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
@@ -254,7 +283,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -273,7 +302,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -293,14 +322,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -330,7 +359,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
@@ -340,7 +369,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Regression, impurityType, maxDepth,
0, maxBins, Sort, categoricalFeaturesInfo)
@@ -358,7 +387,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
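With the return type changed from RandomForestModel to WeightedEnsembleModel, callers now receive equally weighted trees combined by Average. A brief sketch (hedged: `data` is an assumed RDD[LabeledPoint]):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.rdd.RDD

// Sketch: train a 10-tree forest; the result is a WeightedEnsembleModel
// whose trees all carry weight 1.0 and are combined by Average.
def rfExample(data: RDD[LabeledPoint]): Unit = {
  val model = RandomForest.trainClassifier(
    data, 2, Map[Int, Int](), 10, "sqrt", "gini", 5, 32)
  println(s"Forest size: ${model.weakHypotheses.length}")
}
```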
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
new file mode 100644
index 0000000000000..abbda040bd528
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import scala.beans.BeanProperty
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
+
+/**
+ * :: Experimental ::
+ * Stores all the configuration options for the boosting algorithms
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @param numIterations Number of iterations of boosting. In other words, the number of
+ * weak hypotheses used in the final model.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+ *                     learning rate should lie in the interval (0, 1].
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * This setting overrides any setting in [[weakLearnerParams]].
+ * Default value is 2 (binary classification).
+ * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
+ * supported.
+ */
+@Experimental
+case class BoostingStrategy(
+ // Required boosting parameters
+ @BeanProperty var algo: Algo,
+ @BeanProperty var numIterations: Int,
+ @BeanProperty var loss: Loss,
+ // Optional boosting parameters
+ @BeanProperty var learningRate: Double = 0.1,
+ @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
+
+ // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+ weakLearnerParams.numClassesForClassification = numClassesForClassification
+
+ /**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Check validity of parameters.
+ * Throws exception if invalid.
+ */
+ private[tree] def assertValid(): Unit = {
+ algo match {
+ case Classification =>
+ require(numClassesForClassification == 2)
+ case Regression =>
+ // nothing
+ case _ =>
+ throw new IllegalArgumentException(
+ s"BoostingStrategy given invalid algo parameter: $algo." +
+ s" Valid settings are: Classification, Regression.")
+ }
+ require(learningRate > 0 && learningRate <= 1,
+ "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.")
+ }
+}
+
+@Experimental
+object BoostingStrategy {
+
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: String): BoostingStrategy = {
+ val treeStrategy = Strategy.defaultStrategy("Regression")
+ treeStrategy.maxDepth = 3
+ algo match {
+ case "Classification" =>
+ new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+ case "Regression" =>
+ new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
+ weakLearnerParams = treeStrategy)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
+ }
+ }
+}
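Because the fields are @BeanProperty vars, the strategy can also be tuned through the generated setters, which keeps the class usable from Java. A small sketch:

```scala
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

// @BeanProperty generates setLearningRate, setNumIterations, etc.,
// alongside plain Scala assignment.
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.setLearningRate(0.05)
boostingStrategy.setNumIterations(50)
```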
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
new file mode 100644
index 0000000000000..82889dc00cdad
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Enum to select ensemble combining strategy for base learners
+ */
+@DeveloperApi
+object EnsembleCombiningStrategy extends Enumeration {
+ type EnsembleCombiningStrategy = Value
+ val Sum, Average = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index caaccbfb8ad16..b5b1f82177edc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.tree.configuration
+import scala.beans.BeanProperty
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
@@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* for choosing how to split on features at each node.
* More bins give higher granularity.
* @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
- * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
+ * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
@@ -58,31 +59,35 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
+ *                       maintain an RDD that caches each row's current node Id for every tree.
+ * @param checkpointDir If the node Id cache is used, it will help to checkpoint
+ * the node Id cache periodically. This is the checkpoint directory
+ * to be used for the node Id cache.
+ * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
+ * E.g. 10 means that the cache will get checkpointed every 10 updates.
*/
@Experimental
class Strategy (
- val algo: Algo,
- val impurity: Impurity,
- val maxDepth: Int,
- val numClassesForClassification: Int = 2,
- val maxBins: Int = 32,
- val quantileCalculationStrategy: QuantileStrategy = Sort,
- val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
- val minInstancesPerNode: Int = 1,
- val minInfoGain: Double = 0.0,
- val maxMemoryInMB: Int = 256) extends Serializable {
+ @BeanProperty var algo: Algo,
+ @BeanProperty var impurity: Impurity,
+ @BeanProperty var maxDepth: Int,
+ @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var maxBins: Int = 32,
+ @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
+ @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ @BeanProperty var minInstancesPerNode: Int = 1,
+ @BeanProperty var minInfoGain: Double = 0.0,
+ @BeanProperty var maxMemoryInMB: Int = 256,
+ @BeanProperty var subsamplingRate: Double = 1,
+ @BeanProperty var useNodeIdCache: Boolean = false,
+ @BeanProperty var checkpointDir: Option[String] = None,
+ @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- if (algo == Classification) {
- require(numClassesForClassification >= 2)
- }
- require(minInstancesPerNode >= 1,
- s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
- require(maxMemoryInMB <= 10240,
- s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
-
- val isMulticlassClassification =
+ def isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
- val isMulticlassWithCategoricalFeatures
+ def isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
/**
@@ -99,6 +104,23 @@ class Strategy (
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}
+ /**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Sets categoricalFeaturesInfo using a Java Map.
+ */
+ def setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
+ setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ }
+
/**
* Check validity of parameters.
* Throws exception if invalid.
@@ -130,6 +152,26 @@ class Strategy (
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
s" feature $feature has $arity categories. The number of categories should be >= 2.")
}
+ require(minInstancesPerNode >= 1,
+ s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
}
+}
+
+@Experimental
+object Strategy {
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo "Classification" or "Regression"
+ */
+ def defaultStrategy(algo: String): Strategy = algo match {
+ case "Classification" =>
+ new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
+ numClassesForClassification = 2)
+ case "Regression" =>
+ new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
+ numClassesForClassification = 0)
+ }
}
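A configuration sketch for the new fields (hedged: the checkpoint path below is hypothetical):

```scala
import org.apache.spark.mllib.tree.configuration.Strategy

// Enable node Id caching with periodic checkpointing of the cache RDD.
val strategy = Strategy.defaultStrategy("Classification")
strategy.useNodeIdCache = true
strategy.checkpointDir = Some("hdfs:///tmp/tree-node-id-cache") // hypothetical path
strategy.checkpointInterval = 5 // checkpoint every 5 cache updates
```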
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
index e7a2127c5d2e7..089010c81ffb6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -21,13 +21,14 @@ import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
/**
* Internal representation of a datapoint which belongs to several subsamples of the same dataset,
* particularly for bagging (e.g., for random forests).
*
* This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times which this instance appears in each subsample.
+ * number of times this instance appears in each subsample.
* E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
*
@@ -44,22 +45,65 @@ private[tree] object BaggedPoint {
/**
* Convert an input dataset into its BaggedPoint representation,
- * choosing subsample counts for each instance.
- * Each subsample has the same number of instances as the original dataset,
- * and is created by subsampling with replacement.
- * @param input Input dataset.
- * @param numSubsamples Number of subsamples of this RDD to take.
- * @param seed Random seed.
- * @return BaggedPoint dataset representation
+   * choosing subsample counts for each instance.
+   * Each subsample has the same number of instances as the original dataset,
+   * and is created by subsampling with or without replacement.
+ * @param input Input dataset.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param numSubsamples Number of subsamples of this RDD to take.
+ * @param withReplacement Sampling with/without replacement.
+ * @param seed Random seed.
+ * @return BaggedPoint dataset representation.
*/
- def convertToBaggedRDD[Datum](
+  def convertToBaggedRDD[Datum](
input: RDD[Datum],
+ subsamplingRate: Double,
numSubsamples: Int,
+ withReplacement: Boolean,
seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+ if (withReplacement) {
+ convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+ } else {
+ if (numSubsamples == 1 && subsamplingRate == 1.0) {
+ convertToBaggedRDDWithoutSampling(input)
+ } else {
+ convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+ }
+ }
+ }
+
+  private def convertToBaggedRDDSamplingWithoutReplacement[Datum](
+ input: RDD[Datum],
+ subsamplingRate: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+ // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+ val rng = new XORShiftRandom
+ rng.setSeed(seed + partitionIndex + 1)
+ instances.map { instance =>
+ val subsampleWeights = new Array[Double](numSubsamples)
+ var subsampleIndex = 0
+ while (subsampleIndex < numSubsamples) {
+ val x = rng.nextDouble()
+ subsampleWeights(subsampleIndex) = {
+ if (x < subsamplingRate) 1.0 else 0.0
+ }
+ subsampleIndex += 1
+ }
+ new BaggedPoint(instance, subsampleWeights)
+ }
+ }
+ }
+
+  private def convertToBaggedRDDSamplingWithReplacement[Datum](
+ input: RDD[Datum],
+ subsample: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
- // TODO: Support different sampling rates, and sampling without replacement.
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
- val poisson = new PoissonDistribution(1.0)
+ val poisson = new PoissonDistribution(subsample)
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
instances.map { instance =>
val subsampleWeights = new Array[Double](numSubsamples)
@@ -73,7 +117,8 @@ private[tree] object BaggedPoint {
}
}
- def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+  private def convertToBaggedRDDWithoutSampling[Datum](
+ input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
input.map(datum => new BaggedPoint(datum, Array(1.0)))
}
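The without-replacement path draws one Bernoulli(subsamplingRate) weight per subsample for each instance, while the with-replacement path draws Poisson(subsamplingRate) counts instead. A self-contained sketch of the Bernoulli case (hedged: plain scala.util.Random stands in for Spark's internal XORShiftRandom):

```scala
import scala.util.Random

// One weight per subsample: 1.0 keeps the instance, 0.0 drops it.
val rng = new Random(42)
val subsamplingRate = 0.7
val numSubsamples = 3
val subsampleWeights = Array.fill(numSubsamples) {
  if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0
}
```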
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000000000..83011b48b7d9b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
+
+/**
+ * :: DeveloperApi ::
+ * This is used by the node id cache to find the child id that a data point would belong to.
+ * @param split Split information.
+ * @param nodeIndex The current node index of a data point that this will update.
+ */
+@DeveloperApi
+private[tree] case class NodeIndexUpdater(
+ split: Split,
+ nodeIndex: Int) {
+ /**
+ * Determine a child node index based on the feature value and the split.
+ * @param binnedFeatures Binned feature values.
+ * @param bins Bin information to convert the bin indices to approximate feature values.
+ * @return Child node index to update to.
+ */
+ def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
+ if (split.featureType == Continuous) {
+ val featureIndex = split.feature
+ val binIndex = binnedFeatures(featureIndex)
+ val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+ if (featureValueUpperBound <= split.threshold) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ } else {
+ if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * A given TreePoint would belong to a particular node per tree.
+ * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, values should all be 1 for root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ *                           (should be arrays of all 1's, i.e., every instance starts
+ *                           at the root node).
+ * @param checkpointDir The checkpoint directory where
+ * the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ *                           (how often the cache should be checkpointed).
+ */
+@DeveloperApi
+private[tree] class NodeIdCache(
+ var nodeIdsForInstances: RDD[Array[Int]],
+ val checkpointDir: Option[String],
+ val checkpointInterval: Int) {
+
+ // Keep a reference to a previous node Ids for instances.
+ // Because we will keep on re-persisting updated node Ids,
+ // we want to unpersist the previous RDD.
+ private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+ // To keep track of the past checkpointed RDDs.
+ private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+ private var rddUpdateCount = 0
+
+ // If a checkpoint directory is given, and there's no prior checkpoint directory,
+ // then set the checkpoint directory with the given one.
+ if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
+ nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
+ }
+
+ /**
+ * Update the node index values in the cache.
+ * This updates the RDD and its lineage.
+ * TODO: Passing bin information to executors seems unnecessary and costly.
+ * @param data The RDD of training rows.
+ * @param nodeIdUpdaters A map of node index updaters.
+ * The key is the indices of nodes that we want to update.
+ * @param bins Bin information needed to find child node indices.
+ */
+ def updateNodeIndices(
+ data: RDD[BaggedPoint[TreePoint]],
+ nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+ bins: Array[Array[Bin]]): Unit = {
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist()
+ }
+
+ prevNodeIdsForInstances = nodeIdsForInstances
+ nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
+ dataPoint => {
+ var treeId = 0
+ while (treeId < nodeIdUpdaters.length) {
+ val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+ if (nodeIdUpdater != null) {
+ val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+ binnedFeatures = dataPoint._1.datum.binnedFeatures,
+ bins = bins)
+ dataPoint._2(treeId) = newNodeIndex
+ }
+
+ treeId += 1
+ }
+
+ dataPoint._2
+ }
+ }
+
+ // Keep on persisting new ones.
+ nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+ rddUpdateCount += 1
+
+ // Handle checkpointing if the directory is not None.
+ if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
+ (rddUpdateCount % checkpointInterval) == 0) {
+ // Let's see if we can delete previous checkpoints.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // We can delete the oldest checkpoint iff
+ // the next checkpoint actually exists in the file system.
+      if (checkpointQueue(1).getCheckpointFile.isDefined) {
+ val old = checkpointQueue.dequeue()
+
+ // Since the old checkpoint is not deleted by Spark,
+ // we'll manually delete it here.
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ } else {
+ canDelete = false
+ }
+ }
+
+ nodeIdsForInstances.checkpoint()
+ checkpointQueue.enqueue(nodeIdsForInstances)
+ }
+ }
+
+ /**
+ * Call this after training is finished to delete any remaining checkpoints.
+ */
+ def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.nonEmpty) {
+      val old = checkpointQueue.dequeue()
+      if (old.getCheckpointFile.isDefined) {
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ }
+ }
+ }
+}
+
+@DeveloperApi
+private[tree] object NodeIdCache {
+ /**
+ * Initialize the node Id cache with initial node Id values.
+ * @param data The RDD of training rows.
+ * @param numTrees The number of trees that we want to create cache for.
+ * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+   *                           (how often the cache should be checkpointed).
+ * @param initVal The initial values in the cache.
+ * @return A node Id cache containing an RDD of initial root node Indices.
+ */
+ def init(
+ data: RDD[BaggedPoint[TreePoint]],
+ numTrees: Int,
+ checkpointDir: Option[String],
+ checkpointInterval: Int,
+ initVal: Int = 1): NodeIdCache = {
+ new NodeIdCache(
+ data.map(_ => Array.fill[Int](numTrees)(initVal)),
+ checkpointDir,
+ checkpointInterval)
+ }
+}
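The updater relies on MLlib's 1-based binary-heap node numbering: the root is node 1, and the children of node i are 2i and 2i + 1, which is what Node.leftChildIndex and Node.rightChildIndex compute. A sketch of that convention:

```scala
// 1-based heap indexing of tree nodes, as assumed by NodeIndexUpdater.
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

assert(leftChildIndex(1) == 2)  // root's left child
assert(rightChildIndex(1) == 3) // root's right child
assert(leftChildIndex(3) == 6 && rightChildIndex(3) == 7)
```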
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
new file mode 100644
index 0000000000000..d111ffe30ed9e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least absolute error loss calculation.
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: |y - F|
+ * Negative gradient: sign(y - F)
+ */
+@DeveloperApi
+object AbsoluteError extends Loss {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation for least
+ * absolute error calculation.
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Mean absolute error of the model on the data.
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ val sumOfAbsolutes = data.map { y =>
+ val err = model.predict(y.features) - y.label
+ math.abs(err)
+ }.sum()
+ sumOfAbsolutes / data.count()
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
new file mode 100644
index 0000000000000..6f3d4340f0d3b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for log loss calculation (for binary classification).
+ *
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: log(1 + exp(-2yF)), y in {-1, +1}
+ * Negative gradient: 2y / (1 + exp(2yF))
+ */
+@DeveloperApi
+object LogLoss extends Loss {
+
+ /**
+ * Method to calculate the loss gradients for the gradient boosting calculation for binary
+ * classification
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+    val prediction = model.predict(point.features)
+    // Gradient of log(1 + exp(-2yF)) w.r.t. F for labels y in {-1, +1},
+    // matching the label remapping done in GradientBoosting.boost.
+    -2.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Fraction of instances misclassified by the model.
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
+    wrongPredictions.toDouble / data.count()
+ }
+}
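For reference, the gradient returned above follows from differentiating the documented loss (a sketch, with labels remapped to y in {-1, +1} as GradientBoosting.boost does):

```latex
L(y, F) = \log\left(1 + e^{-2yF}\right), \qquad y \in \{-1, +1\}

\frac{\partial L}{\partial F}
  = \frac{-2y\, e^{-2yF}}{1 + e^{-2yF}}
  = \frac{-2y}{1 + e^{2yF}}
```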
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
new file mode 100644
index 0000000000000..5580866c879e2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
+ */
+@DeveloperApi
+trait Loss extends Serializable {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation.
+ * @param model Model of the weak learner.
+ * @param point Instance of the training dataset.
+ * @return Loss gradient.
+ */
+ def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Measure of the model's error on the data.
+ */
+ def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
new file mode 100644
index 0000000000000..42c9ead9884b4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+object Losses {
+
+ def fromString(name: String): Loss = name match {
+ case "leastSquaresError" => SquaredError
+ case "leastAbsoluteError" => AbsoluteError
+ case "logLoss" => LogLoss
+ case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name")
+ }
+
+}
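A quick usage sketch (the string keys are exactly those matched above):

```scala
val loss = Losses.fromString("logLoss") // returns the LogLoss singleton
// Losses.fromString("hinge") would throw IllegalArgumentException: unrecognized name
```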
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
new file mode 100644
index 0000000000000..4349fefef2c74
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least squares error loss calculation.
+ *
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: (y - F)**2/2
+ * Negative gradient: y - F
+ */
+@DeveloperApi
+object SquaredError extends Loss {
+
+ /**
+   * Method to calculate the gradient of the least squares error loss for the gradient
+   * boosting calculation.
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ model.predict(point.features) - point.label
+ }
+
+  /**
+   * Method to calculate error of the base learner for the gradient boosting calculation.
+   * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+   * purposes.
+   * @param model Model of the weak learner.
+   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Mean squared error of the model on the data.
+   */
+  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+    data.map { lp =>
+      val err = model.predict(lp.features) - lp.label
+      err * err
+    }.mean()
+  }
+
+}
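As a sanity check on the sign convention: differentiating (y - F)**2/2 with respect to F gives F - y, which is what `gradient` returns. A tiny standalone sketch:

```scala
// d/dF of (label - F)^2 / 2, evaluated at F = prediction.
def squaredErrorGradient(prediction: Double, label: Double): Double =
  prediction - label

assert(squaredErrorGradient(3.0, 1.0) == 2.0) // over-prediction => positive gradient
```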
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
deleted file mode 100644
index 6a22e2abe59bd..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import scala.collection.mutable
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.rdd.RDD
-
-/**
- * :: Experimental ::
- * Random forest model for classification or regression.
- * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
- * aggregate predictions.
- * @param trees Trees which make up this forest. This cannot be empty.
- * @param algo algorithm type -- classification or regression
- */
-@Experimental
-class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
-
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
-
- /**
- * Predict values for a single data point.
- *
- * @param features array representing a single data point
- * @return Double prediction from the trained model
- */
- def predict(features: Vector): Double = {
- algo match {
- case Classification =>
- val predictionToCount = new mutable.HashMap[Int, Int]()
- trees.foreach { tree =>
- val prediction = tree.predict(features).toInt
- predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
- }
- predictionToCount.maxBy(_._2)._1
- case Regression =>
- trees.map(_.predict(features)).sum / trees.size
- }
- }
-
- /**
- * Predict values for the given data set.
- *
- * @param features RDD representing data points to be predicted
- * @return RDD[Double] where each entry contains the corresponding prediction
- */
- def predict(features: RDD[Vector]): RDD[Double] = {
- features.map(x => predict(x))
- }
-
- /**
- * Get number of trees in forest.
- */
- def numTrees: Int = trees.size
-
- /**
- * Get total number of nodes, summed over all trees in the forest.
- */
- def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
-
- /**
- * Print a summary of the model.
- */
- override def toString: String = algo match {
- case Classification =>
- s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
- case Regression =>
- s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
- case _ => throw new IllegalArgumentException(
- s"RandomForestModel given unknown algo parameter: $algo.")
- }
-
- /**
- * Print the full model to a string.
- */
- def toDebugString: String = {
- val header = toString + "\n"
- header + trees.zipWithIndex.map { case (tree, treeIndex) =>
- s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
- }.fold("")(_ + _)
- }
-
-}
-
-private[tree] object RandomForestModel {
-
- def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
- val algo: Algo = trees(0).algo
- require(trees.forall(_.algo == algo),
- "RandomForestModel cannot combine trees which have different output types" +
- " (classification/regression).")
- new RandomForestModel(trees, algo)
- }
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
new file mode 100644
index 0000000000000..7b052d9163a13
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
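+/**
+ * :: Experimental ::
+ * Represents an ensemble model: a weighted collection of weak hypotheses (decision trees)
+ * whose predictions are combined either by weighted sum or by averaging / majority vote.
+ * @param weakHypotheses Trees which make up this ensemble.
+ * @param weakHypothesisWeights Weight of each tree; used only by the Sum combining strategy.
+ * @param algo algorithm type -- classification or regression
+ * @param combiningStrategy strategy for combining predictions (Sum or Average)
+ */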
+@Experimental
+class WeightedEnsembleModel(
+ val weakHypotheses: Array[DecisionTreeModel],
+ val weakHypothesisWeights: Array[Double],
+ val algo: Algo,
+ val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+  require(numWeakHypotheses > 0,
+    s"WeightedEnsembleModel cannot be created without weakHypotheses. " +
+    s"Number of weakHypotheses = $numWeakHypotheses")
+
+  /**
+   * Compute the raw (weighted-sum) prediction for a single data point.
+   *
+   * @param features array representing a single data point
+   * @return weighted sum of the weak hypotheses' predictions
+   */
+  private def predictRaw(features: Vector): Double = {
+    val treePredictions = weakHypotheses.map(learner => learner.predict(features))
+    // Weighted sum over all weak hypotheses, including the first one.
+    var prediction = 0.0
+    var index = 0
+    while (index < numWeakHypotheses) {
+      prediction += weakHypothesisWeights(index) * treePredictions(index)
+      index += 1
+    }
+    prediction
+  }
+
+  /**
+   * Predict values for a single data point using the model trained.
+   *
+   * @param features array representing a single data point
+   * @return predicted value (category for classification) from the trained model
+   */
+  private def predictBySumming(features: Vector): Double = {
+    algo match {
+      case Regression => predictRaw(features)
+      case Classification =>
+        // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+        if (predictRaw(features) > 0) 1.0 else 0.0
+      case _ => throw new IllegalArgumentException(
+        s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+    }
+  }
+
+ /**
+ * Predict values for a single data point.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+  private def predictByAveraging(features: Vector): Double = {
+    algo match {
+      case Classification =>
+        val predictionToCount = new mutable.HashMap[Int, Int]()
+        weakHypotheses.foreach { learner =>
+          val prediction = learner.predict(features).toInt
+          predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
+        }
+        predictionToCount.maxBy(_._2)._1
+      case Regression =>
+        weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size
+      case _ => throw new IllegalArgumentException(
+        s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+    }
+  }
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+   * @return predicted value (category for classification) from the trained model
+ */
+ def predict(features: Vector): Double = {
+ combiningStrategy match {
+ case Sum => predictBySumming(features)
+ case Average => predictByAveraging(features)
+ case _ => throw new IllegalArgumentException(
+ s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.")
+ }
+ }
+
+ /**
+ * Predict values for the given data set.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+ /**
+ * Print a summary of the model.
+ */
+  override def toString: String = {
+    algo match {
+      case Classification =>
+        s"WeightedEnsembleModel classifier with $numWeakHypotheses trees"
+      case Regression =>
+        s"WeightedEnsembleModel regressor with $numWeakHypotheses trees"
+      case _ => throw new IllegalArgumentException(
+        s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+    }
+  }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) =>
+ s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+  /**
+   * Get number of weak hypotheses (trees) in the ensemble.
+   */
+ def numWeakHypotheses: Int = weakHypotheses.size
+
+  // TODO: Remove these helper methods once the class is generalized to support any base learning
+  // algorithms.
+
+  /**
+   * Get total number of nodes, summed over all trees in the ensemble.
+   */
+  def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
+
+}
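The Sum path above reduces to a weighted dot product over per-tree predictions; a minimal pure-function sketch of that combining logic (illustrative, not part of the class):

```scala
def weightedSum(treePredictions: Array[Double], weights: Array[Double]): Double = {
  require(treePredictions.length == weights.length && treePredictions.nonEmpty)
  var acc = 0.0
  var i = 0
  while (i < treePredictions.length) { // while loop mirrors the hot-path style above
    acc += weights(i) * treePredictions(i)
    i += 1
  }
  acc
}
```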
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index b88e08bf148ae..9353351af72a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.util.random.BernoulliSampler
+import org.apache.spark.util.random.BernoulliCellSampler
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.storage.StorageLevel
@@ -244,7 +244,7 @@ object MLUtils {
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
- val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
+ val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
complement = false)
val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
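For reference, `kFold` is typically driven like this (a minimal sketch assuming an existing `SparkContext` named `sc`):

```scala
import org.apache.spark.mllib.util.MLUtils

val data = sc.parallelize(1 to 100)
// Each fold yields a (training, validation) pair that partitions the RDD.
val folds = MLUtils.kFold(data, numFolds = 3, seed = 11)
folds.foreach { case (training, validation) =>
  assert(training.count() + validation.count() == 100)
}
```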
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
new file mode 100644
index 0000000000000..850c9fce507cd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.random.XORShiftRandom
+
+class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
+
+ override def maxWaitTimeMillis = 30000
+
+ test("accuracy for single center and equivalence to grand average") {
+ // set parameters
+ val numBatches = 10
+ val numPoints = 50
+ val k = 1
+ val d = 5
+ val r = 0.1
+
+ // create model with one cluster
+ val model = new StreamingKMeans()
+ .setK(1)
+ .setDecayFactor(1.0)
+ .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0))
+
+ // generate random data for k-means
+ val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ model.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // estimated center should be close to true center
+ assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
+
+ // estimated center from streaming should exactly match the arithmetic mean of all data points
+ // because the decay factor is set to 1.0
+ val grandMean =
+      input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble
+ assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)
+ }
+
+ test("accuracy for two centers") {
+ val numBatches = 10
+ val numPoints = 5
+ val k = 2
+ val d = 5
+ val r = 0.1
+
+ // create model with two clusters
+ val kMeans = new StreamingKMeans()
+ .setK(2)
+ .setHalfLife(2, "batches")
+ .setInitialCenters(
+ Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
+ Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)),
+ Array(5.0, 5.0))
+
+ // generate random data for k-means
+ val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ kMeans.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // check that estimated centers are close to true centers
+ // NOTE exact assignment depends on the initialization!
+ assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1)
+ assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1)
+ }
+
+ test("detecting dying clusters") {
+ val numBatches = 10
+ val numPoints = 5
+ val k = 1
+ val d = 1
+ val r = 1.0
+
+ // create model with two clusters
+ val kMeans = new StreamingKMeans()
+ .setK(2)
+ .setHalfLife(0.5, "points")
+ .setInitialCenters(
+ Array(Vectors.dense(0.0), Vectors.dense(1000.0)),
+ Array(1.0, 1.0))
+
+ // new data are all around the first cluster 0.0
+ val (input, _) =
+ StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ kMeans.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // check that estimated centers are close to true centers
+ // NOTE exact assignment depends on the initialization!
+ val model = kMeans.latestModel()
+ val c0 = model.clusterCenters(0)(0)
+ val c1 = model.clusterCenters(1)(0)
+
+ assert(c0 * c1 < 0.0, "should have one positive center and one negative center")
+    // sqrt(2 / pi) ~= 0.8 is the mean of the standard half-normal distribution
+ assert(math.abs(c0) ~== 0.8 absTol 0.6)
+ assert(math.abs(c1) ~== 0.8 absTol 0.6)
+ }
+
+ def StreamingKMeansDataGenerator(
+ numPoints: Int,
+ numBatches: Int,
+ k: Int,
+ d: Int,
+ r: Double,
+ seed: Int,
+ initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
+ val rand = new XORShiftRandom(seed)
+ val centers = initCenters match {
+ case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
+ case _ => initCenters
+ }
+ val data = (0 until numBatches).map { i =>
+ (0 until numPoints).map { idx =>
+ val center = centers(idx % k)
+ Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
+ }
+ }
+ (data, centers)
+ }
+}
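The `setHalfLife` calls above are just a reparameterization of the decay factor: a half-life h means the decay factor a satisfies a^h = 0.5. A one-line sketch of that conversion (stated here as the relationship the tests rely on):

```scala
// Decay factor implied by a half-life h: a = 0.5^(1/h)
def decayFromHalfLife(halfLife: Double): Double = math.exp(math.log(0.5) / halfLife)

assert(math.abs(decayFromHalfLife(2.0) - math.sqrt(0.5)) < 1e-12)
```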
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
new file mode 100644
index 0000000000000..342baa0274e9c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.rdd.RDD
+
+class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
+ test("Multilabel evaluation metrics") {
+ /*
+ * Documents true labels (5x class0, 3x class1, 4x class2):
+ * doc 0 - predict 0, 1 - class 0, 2
+ * doc 1 - predict 0, 2 - class 0, 1
+ * doc 2 - predict none - class 0
+ * doc 3 - predict 2 - class 2
+ * doc 4 - predict 2, 0 - class 2, 0
+ * doc 5 - predict 0, 1, 2 - class 0, 1
+ * doc 6 - predict 1 - class 1, 2
+ *
+ * predicted classes
+ * class 0 - doc 0, 1, 4, 5 (total 4)
+ * class 1 - doc 0, 5, 6 (total 3)
+ * class 2 - doc 1, 3, 4, 5 (total 4)
+ *
+ * true classes
+ * class 0 - doc 0, 1, 2, 4, 5 (total 5)
+ * class 1 - doc 1, 5, 6 (total 3)
+ * class 2 - doc 0, 3, 4, 6 (total 4)
+ *
+ */
+ val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
+ Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
+ (Array(0.0, 2.0), Array(0.0, 1.0)),
+ (Array(), Array(0.0)),
+ (Array(2.0), Array(2.0)),
+ (Array(2.0, 0.0), Array(2.0, 0.0)),
+ (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
+ (Array(1.0), Array(1.0, 2.0))), 2)
+ val metrics = new MultilabelMetrics(scoreAndLabels)
+ val delta = 0.00001
+ val precision0 = 4.0 / (4 + 0)
+ val precision1 = 2.0 / (2 + 1)
+ val precision2 = 2.0 / (2 + 2)
+ val recall0 = 4.0 / (4 + 1)
+ val recall1 = 2.0 / (2 + 1)
+ val recall2 = 2.0 / (2 + 2)
+ val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
+ val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
+ val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
+ val sumTp = 4 + 2 + 2
+ assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1))
+ val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2)
+ val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
+ val microF1MeasureClass = 2.0 * sumTp.toDouble /
+ (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))
+ val macroPrecisionDoc = 1.0 / 7 *
+ (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
+ val macroRecallDoc = 1.0 / 7 *
+ (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
+ val macroF1MeasureDoc = (1.0 / 7) *
+ 2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) +
+ 2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) )
+ val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
+ val strictAccuracy = 2.0 / 7
+    val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 / 3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
+ assert(math.abs(metrics.precision(0.0) - precision0) < delta)
+ assert(math.abs(metrics.precision(1.0) - precision1) < delta)
+ assert(math.abs(metrics.precision(2.0) - precision2) < delta)
+ assert(math.abs(metrics.recall(0.0) - recall0) < delta)
+ assert(math.abs(metrics.recall(1.0) - recall1) < delta)
+ assert(math.abs(metrics.recall(2.0) - recall2) < delta)
+ assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
+ assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
+ assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
+ assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
+ assert(math.abs(metrics.microRecall - microRecallClass) < delta)
+ assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
+ assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
+ assert(math.abs(metrics.recall - macroRecallDoc) < delta)
+ assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
+ assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
+ assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
+ assert(math.abs(metrics.accuracy - accuracy) < delta)
+ assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
+ }
+}
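The micro-averaged figures pool true/false positives across classes before dividing; a quick standalone check of the micro precision asserted above (counts taken from the comment block):

```scala
val tp = 4 + 2 + 2 // true positives summed over classes 0, 1, 2
val fp = 0 + 1 + 2 // false positives summed over classes 0, 1, 2
val microPrecision = tp.toDouble / (tp + fp)
assert(math.abs(microPrecision - 8.0 / 11.0) < 1e-9)
```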
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
new file mode 100644
index 0000000000000..5396d7b2b74fa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RegressionMetricsSuite extends FunSuite with LocalSparkContext {
+
+ test("regression metrics") {
+ val predictionAndObservations = sc.parallelize(
+      Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservations)
+ assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
+ }
+
+ test("regression metrics with complete fitting") {
+ val predictionAndObservations = sc.parallelize(
+      Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservations)
+ assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch")
+ }
+}
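The expected values in the first test follow directly from the residuals; a quick sketch of the arithmetic:

```scala
// prediction - observation for the four pairs in the first test
val residuals = Seq(2.5 - 3.0, 0.0 - (-0.5), 2.0 - 2.0, 8.0 - 7.0)
val mse = residuals.map(r => r * r).sum / residuals.size  // 0.375
val rmse = math.sqrt(mse)                                 // ~0.61237
val mae = residuals.map(math.abs).sum / residuals.size    // 0.5
```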
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index cd651fe2d2ddf..93a84fe07b32a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite {
throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
}
}
+
+ test("VectorUDT") {
+ val dv0 = Vectors.dense(Array.empty[Double])
+ val dv1 = Vectors.dense(1.0, 2.0)
+ val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
+ val sv1 = Vectors.sparse(2, Array(1), Array(2.0))
+ val udt = new VectorUDT()
+ for (v <- Seq(dv0, dv1, sv0, sv1)) {
+ assert(v === udt.deserialize(udt.serialize(v)))
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 27a19f793242b..4ef67a40b9f49 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
val rdd = sc.parallelize(data, data.length).flatMap(s => s)
assert(rdd.partitions.size === data.length)
- val sliding = rdd.sliding(3)
- val expected = data.flatMap(x => x).sliding(3).toList
- assert(sliding.collect().toList === expected)
+ val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq)
+ val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
+ assert(sliding === expected)
}
test("treeAggregate") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8fc5e111bbc17..c579cb58549f5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -493,7 +493,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(rootNode1.rightNode.nonEmpty)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
// Single group second level tree construction.
val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
@@ -786,7 +786,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
@@ -829,7 +829,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
new file mode 100644
index 0000000000000..effb7b8259ffb
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.mutable
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.util.StatCounter
+
+object EnsembleTestHelper {
+
+ /**
+ * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+ * epsilon of the expected values.
+ * @param data Every element of the data should be an i.i.d. sample from some distribution.
+ */
+ def testRandomArrays(
+ data: Array[Array[Double]],
+ numCols: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double) {
+ val values = new mutable.ArrayBuffer[Double]()
+ data.foreach { row =>
+ assert(row.size == numCols)
+ values ++= row
+ }
+ val stats = new StatCounter(values)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ def validateClassifier(
+ model: WeightedEnsembleModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+ }
+
+ def validateRegressor(
+ model: WeightedEnsembleModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ val err = prediction - expected.label
+ err * err
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ }
+
+ def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](numInstances)
+ for (i <- 0 until numInstances) {
+ val label = if (i < numInstances / 10) {
+ 0.0
+ } else if (i < numInstances / 2) {
+ 1.0
+ } else if (i < numInstances * 0.9) {
+ 0.0
+ } else {
+ 1.0
+ }
+ val features = Array.fill[Double](numFeatures)(i.toDouble)
+ arr(i) = new LabeledPoint(label, Vectors.dense(features))
+ }
+ arr
+ }
+
+}
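A sketch of how the helper is meant to be driven from a suite (assumes some already-trained `WeightedEnsembleModel` named `model`):

```scala
val points = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, numInstances = 100)
// Labels follow a fixed 0/1/0/1 banding over the ordered instances, so shallow
// trees can separate them; require at least 90% accuracy on the training data.
EnsembleTestHelper.validateClassifier(model, points, requiredAccuracy = 0.9)
```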
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
new file mode 100644
index 0000000000000..99a02eda60baf
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError}
+
+import org.apache.spark.mllib.util.LocalSparkContext
+
+/**
+ * Test suite for [[GradientBoosting]].
+ */
+class GradientBoostingSuite extends FunSuite with LocalSparkContext {
+
+ test("Regression with continuous features: SquaredError") {
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numIterations, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+ learningRate, 1, treeStrategy)
+
+ val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numIterations)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+ test("Regression with continuous features: Absolute Error") {
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numIterations, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+        val boostingStrategy = new BoostingStrategy(Regression, numIterations, AbsoluteError,
+ learningRate, numClassesForClassification = 2, treeStrategy)
+
+ val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numIterations)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+ test("Binary classification with continuous features: Log Loss") {
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numIterations, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss,
+ learningRate, numClassesForClassification = 2, treeStrategy)
+
+ val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numIterations)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+}
+
+object GradientBoostingSuite {
+
+  // Combinations of (numIterations, learningRate, subsamplingRate) exercised by the tests above
+ val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index d3eff59aa0409..73c4393c3581a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -25,100 +25,91 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.util.StatCounter
/**
* Test suite for [[RandomForest]].
*/
class RandomForestSuite extends FunSuite with LocalSparkContext {
-
- test("BaggedPoint RDD: without subsampling") {
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
- val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd)
- baggedRDD.collect().foreach { baggedPoint =>
- assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
- }
- }
-
- test("BaggedPoint RDD: with subsampling") {
- val numSubsamples = 100
- val (expectedMean, expectedStddev) = (1.0, 1.0)
-
- val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+ def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
- seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
- RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
- expectedStddev, epsilon = 0.01)
- }
- }
-
- test("Binary classification with continuous features:" +
- " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
-
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
- val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
-
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
- val rfTree = rf.trees(0)
+ assert(rf.weakHypotheses.size === 1)
+ val rfTree = rf.weakHypotheses(0)
val dt = DecisionTree.train(rdd, strategy)
- RandomForestSuite.validateClassifier(rf, arr, 0.9)
+ EnsembleTestHelper.validateClassifier(rf, arr, 0.9)
DecisionTreeSuite.validateClassifier(dt, arr, 0.9)
// Make sure trees are the same.
assert(rfTree.toString == dt.toString)
}
- test("Regression with continuous features:" +
+ test("Binary classification with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeatures(strategy)
+ }
+
+ test("Binary classification with continuous features and node Id cache :" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+      useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeatures(strategy)
+ }
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+ def regressionTestWithContinuousFeatures(strategy: Strategy) {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
- val strategy = new Strategy(algo = Regression, impurity = Variance,
- maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
- categoricalFeaturesInfo = categoricalFeaturesInfo)
-
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
- val rfTree = rf.trees(0)
+ assert(rf.weakHypotheses.size === 1)
+ val rfTree = rf.weakHypotheses(0)
val dt = DecisionTree.train(rdd, strategy)
- RandomForestSuite.validateRegressor(rf, arr, 0.01)
+ EnsembleTestHelper.validateRegressor(rf, arr, 0.01)
DecisionTreeSuite.validateRegressor(dt, arr, 0.01)
// Make sure trees are the same.
assert(rfTree.toString == dt.toString)
}
- test("Binary classification with continuous features: subsampling features") {
- val numFeatures = 50
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures)
- val rdd = sc.parallelize(arr)
+ test("Regression with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Regression, impurity = Variance,
+ maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo)
+ regressionTestWithContinuousFeatures(strategy)
+ }
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ test("Regression with continuous features and node Id cache :" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Regression, impurity = Variance,
+ maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+ regressionTestWithContinuousFeatures(strategy)
+ }
+
+ def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
+ val numFeatures = 50
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
+ val rdd = sc.parallelize(arr)
// Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy(
@@ -174,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
}
+ test("Binary classification with continuous features: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
+ test("Binary classification with continuous features and node Id cache: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+      useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
test("alternating categorical and continuous features with multiclass labels to test indexing") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
@@ -187,77 +192,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
- RandomForestSuite.validateClassifier(model, arr, 0.0)
+ EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
-
}
-object RandomForestSuite {
-
- /**
- * Aggregates all values in data, and tests whether the empirical mean and stddev are within
- * epsilon of the expected values.
- * @param data Every element of the data should be an i.i.d. sample from some distribution.
- */
- def testRandomArrays(
- data: Array[Array[Double]],
- numCols: Int,
- expectedMean: Double,
- expectedStddev: Double,
- epsilon: Double) {
- val values = new mutable.ArrayBuffer[Double]()
- data.foreach { row =>
- assert(row.size == numCols)
- values ++= row
- }
- val stats = new StatCounter(values)
- assert(math.abs(stats.mean - expectedMean) < epsilon)
- assert(math.abs(stats.stdev - expectedStddev) < epsilon)
- }
-
- def validateClassifier(
- model: RandomForestModel,
- input: Seq[LabeledPoint],
- requiredAccuracy: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
- prediction != expected.label
- }
- val accuracy = (input.length - numOffPredictions).toDouble / input.length
- assert(accuracy >= requiredAccuracy,
- s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
- }
- def validateRegressor(
- model: RandomForestModel,
- input: Seq[LabeledPoint],
- requiredMSE: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val squaredError = predictions.zip(input).map { case (prediction, expected) =>
- val err = prediction - expected.label
- err * err
- }.sum
- val mse = squaredError / input.length
- assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
- }
-
- def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = {
- val numInstances = 1000
- val arr = new Array[LabeledPoint](numInstances)
- for (i <- 0 until numInstances) {
- val label = if (i < numInstances / 10) {
- 0.0
- } else if (i < numInstances / 2) {
- 1.0
- } else if (i < numInstances * 0.9) {
- 0.0
- } else {
- 1.0
- }
- val features = Array.fill[Double](numFeatures)(i.toDouble)
- arr(i) = new LabeledPoint(label, Vectors.dense(features))
- }
- arr
- }
-
-}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
new file mode 100644
index 0000000000000..5cb433232e714
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.EnsembleTestHelper
+import org.apache.spark.mllib.util.LocalSparkContext
+
+/**
+ * Test suite for [[BaggedPoint]].
+ */
+class BaggedPointSuite extends FunSuite with LocalSparkContext {
+
+ test("BaggedPoint RDD: without subsampling") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+ baggedRDD.collect().foreach { baggedPoint =>
+ assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 1.0)
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") {
+ val numSubsamples = 100
+ val subsample = 0.5
+ val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample))
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 0)
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") {
+ val numSubsamples = 100
+ val subsample = 0.5
+ val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample)))
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+}
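The expected means and standard deviations asserted above come from the sampling distributions: per-point counts are Poisson(fraction) when sampling with replacement and Bernoulli(fraction) without. A small sketch of those moments:

```scala
// With replacement: count ~ Poisson(f) => mean f, stddev sqrt(f)
def withReplacementStats(f: Double): (Double, Double) = (f, math.sqrt(f))
// Without replacement: count ~ Bernoulli(f) => mean f, stddev sqrt(f * (1 - f))
def withoutReplacementStats(f: Double): (Double, Double) = (f, math.sqrt(f * (1 - f)))

val (m, s) = withoutReplacementStats(0.5)
assert(m == 0.5 && math.abs(s - 0.5) < 1e-12) // sqrt(0.25) = 0.5
```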
diff --git a/network/common/pom.xml b/network/common/pom.xml
index e3b7e328701b4..8b24ebf1ba1f2 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -27,12 +27,12 @@
   <groupId>org.apache.spark</groupId>
-  <artifactId>network</artifactId>
+  <artifactId>spark-network-common_2.10</artifactId>
   <packaging>jar</packaging>
-  <name>Shuffle Streaming Service</name>
+  <name>Spark Project Networking</name>
   <url>http://spark.apache.org/</url>
   <properties>
-    <sbt.project.name>network</sbt.project.name>
+    <sbt.project.name>network-common</sbt.project.name>
   </properties>
@@ -50,6 +50,7 @@
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
+      <version>11.0.2</version>
       <scope>provided</scope>
     </dependency>
@@ -59,6 +60,11 @@
       <groupId>junit</groupId>
       <artifactId>junit</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.novocode</groupId>
+      <artifactId>junit-interface</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>log4j</groupId>
       <artifactId>log4j</artifactId>
@@ -69,25 +75,36 @@
       <artifactId>mockito-all</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>

   <build>
-    <outputDirectory>target/java/classes</outputDirectory>
-    <testOutputDirectory>target/java/test-classes</testOutputDirectory>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     <plugins>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-surefire-plugin</artifactId>
-        <version>2.17</version>
-        <configuration>
-          <skipTests>false</skipTests>
-          <includes>
-            <include>**/Test*.java</include>
-            <include>**/*Test.java</include>
-            <include>**/*Suite.java</include>
-          </includes>
-        </configuration>
+        <artifactId>maven-jar-plugin</artifactId>
+        <version>2.2</version>
+        <executions>
+          <execution>
+            <goals>
+              <goal>test-jar</goal>
+            </goals>
+          </execution>
+          <execution>
+            <id>test-jar-on-test-compile</id>
+            <phase>test-compile</phase>
+            <goals>
+              <goal>test-jar</goal>
+            </goals>
+          </execution>
+        </executions>
      </plugin>
    </plugins>
  </build>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index 854aa6685f85f..5bc6e5a2418a9 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,12 +17,16 @@
package org.apache.spark.network;
+import java.util.List;
+
+import com.google.common.collect.Lists;
import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.MessageDecoder;
@@ -52,26 +56,39 @@ public class TransportContext {
private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
private final TransportConf conf;
- private final StreamManager streamManager;
private final RpcHandler rpcHandler;
private final MessageEncoder encoder;
private final MessageDecoder decoder;
- public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) {
+ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this.conf = conf;
- this.streamManager = streamManager;
this.rpcHandler = rpcHandler;
this.encoder = new MessageEncoder();
this.decoder = new MessageDecoder();
}
+ /**
+ * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
+ * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
+ * to create a Client.
+ */
+ public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
+ return new TransportClientFactory(this, bootstraps);
+ }
+
public TransportClientFactory createClientFactory() {
- return new TransportClientFactory(this);
+ return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
+ }
+
+ /** Creates a server which will attempt to bind to a specific port. */
+ public TransportServer createServer(int port) {
+ return new TransportServer(this, port);
}
+ /** Creates a new server, binding to any available ephemeral port. */
public TransportServer createServer() {
- return new TransportServer(this);
+ return new TransportServer(this, 0);
}
/**
@@ -109,7 +126,7 @@ private TransportChannelHandler createChannelHandler(Channel channel) {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
- streamManager, rpcHandler);
+ rpcHandler);
return new TransportChannelHandler(client, responseHandler, requestHandler);
}
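
The reworked TransportContext drops the constructor-injected StreamManager (it now comes from the RpcHandler) and adds bootstrap-aware client factories plus servers bound to explicit ports. A hedged wiring sketch; MyRpcHandler and the port are hypothetical, and SystemPropertyConfigProvider is assumed to be the module's default ConfigProvider:

    import com.google.common.collect.Lists;
    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;
    import org.apache.spark.network.client.TransportClientFactory;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class TransportWiring {
      public static void main(String[] args) throws Exception {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        TransportContext context = new TransportContext(conf, new MyRpcHandler());

        TransportServer server = context.createServer(4040);   // specific port
        TransportClientFactory factory = context.createClientFactory(
            Lists.<TransportClientBootstrap>newArrayList());   // no bootstraps
        TransportClient client = factory.createClient("localhost", 4040);

        client.close();
        factory.close();
        server.close();
      }
    }
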
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
index 89ed79bc63903..844eff4f4c701 100644
--- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -30,24 +30,20 @@
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.network.util.TransportConf;
/**
* A {@link ManagedBuffer} backed by a segment in a file.
*/
public final class FileSegmentManagedBuffer extends ManagedBuffer {
-
- /**
- * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
- * Avoid unless there's a good reason not to.
- */
- // TODO: Make this configurable
- private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
-
+ private final TransportConf conf;
private final File file;
private final long offset;
private final long length;
- public FileSegmentManagedBuffer(File file, long offset, long length) {
+ public FileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length) {
+ this.conf = conf;
this.file = file;
this.offset = offset;
this.length = length;
@@ -64,7 +60,7 @@ public ByteBuffer nioByteBuffer() throws IOException {
try {
channel = new RandomAccessFile(file, "r").getChannel();
// Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
- if (length < MIN_MEMORY_MAP_BYTES) {
+ if (length < conf.memoryMapBytes()) {
ByteBuffer buf = ByteBuffer.allocate((int) length);
channel.position(offset);
while (buf.remaining() != 0) {
@@ -101,7 +97,7 @@ public InputStream createInputStream() throws IOException {
try {
is = new FileInputStream(file);
ByteStreams.skipFully(is, offset);
- return ByteStreams.limit(is, length);
+ return new LimitedInputStream(is, length);
} catch (IOException e) {
try {
if (is != null) {
@@ -133,8 +129,12 @@ public ManagedBuffer release() {
@Override
public Object convertToNetty() throws IOException {
- FileChannel fileChannel = new FileInputStream(file).getChannel();
- return new DefaultFileRegion(fileChannel, offset, length);
+ if (conf.lazyFileDescriptor()) {
+ return new LazyFileRegion(file, offset, length);
+ } else {
+ FileChannel fileChannel = new FileInputStream(file).getChannel();
+ return new DefaultFileRegion(fileChannel, offset, length);
+ }
}
public File getFile() { return file; }
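
The copy-vs-mmap split in nioByteBuffer is the behavior worth internalizing: small segments are cheaper to copy into a heap buffer, while large ones amortize the fixed cost (and JVM risk, per the removed SPARK-1145/SPARK-3889 comment) of memory mapping. An isolated sketch of that decision, with the threshold parameter standing in for conf.memoryMapBytes():

    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.channels.FileChannel;

    final class SegmentReader {
      static ByteBuffer read(FileChannel channel, long offset, long length, long mmapThreshold)
          throws IOException {
        if (length < mmapThreshold) {   // small segment: plain copy
          ByteBuffer buf = ByteBuffer.allocate((int) length);
          channel.position(offset);
          while (buf.remaining() != 0) {
            if (channel.read(buf) == -1) {
              throw new IOException("Reached EOF before filling buffer");
            }
          }
          buf.flip();
          return buf;
        }
        // large segment: map read-only to avoid the copy
        return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
      }
    }
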
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java
new file mode 100644
index 0000000000000..81bc8ec40fc82
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.FileInputStream;
+import java.io.File;
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+
+import com.google.common.base.Objects;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A FileRegion implementation that only creates the file descriptor when the region is being
+ * transferred. This cannot be used with Epoll because there is no native support for it.
+ *
+ * This is mostly copied from the DefaultFileRegion implementation in Netty. In the future, we
+ * should push this into Netty so the native Epoll transport can support this feature.
+ */
+public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion {
+
+ private final File file;
+ private final long position;
+ private final long count;
+
+ private FileChannel channel;
+
+ private long numBytesTransferred = 0L;
+
+ /**
+ * @param file file to transfer.
+ * @param position start position for the transfer.
+ * @param count number of bytes to transfer starting from position.
+ */
+ public LazyFileRegion(File file, long position, long count) {
+ this.file = file;
+ this.position = position;
+ this.count = count;
+ }
+
+ @Override
+ protected void deallocate() {
+ JavaUtils.closeQuietly(channel);
+ }
+
+ @Override
+ public long position() {
+ return position;
+ }
+
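+ /** Name matches Netty 4.x's (historically misspelled) FileRegion#transfered() accessor. */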
+ @Override
+ public long transfered() {
+ return numBytesTransferred;
+ }
+
+ @Override
+ public long count() {
+ return count;
+ }
+
+ @Override
+ public long transferTo(WritableByteChannel target, long position) throws IOException {
+ if (channel == null) {
+ channel = new FileInputStream(file).getChannel();
+ }
+
+ long count = this.count - position;
+ if (count < 0 || position < 0) {
+ throw new IllegalArgumentException(
+ "position out of range: " + position + " (expected: 0 - " + (this.count - 1) + ')');
+ }
+
+ if (count == 0) {
+ return 0L;
+ }
+
+ long written = channel.transferTo(this.position + position, count, target);
+ if (written > 0) {
+ numBytesTransferred += written;
+ }
+ return written;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("position", position)
+ .add("count", count)
+ .toString();
+ }
+}
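
A hedged usage sketch: since the descriptor is opened only on the first transferTo() call, a server can queue many LazyFileRegions without holding as many open files. Per the class comment this works on the NIO transport only, not epoll; the channel and file below are assumptions:

    import java.io.File;
    import io.netty.channel.Channel;
    import io.netty.channel.FileRegion;
    import org.apache.spark.network.buffer.LazyFileRegion;

    final class LazyRegionExample {
      // The file descriptor is opened when Netty first transfers bytes.
      static void writeFile(Channel channel, File file) {
        FileRegion region = new LazyFileRegion(file, 0, file.length());
        channel.writeAndFlush(region);
      }
    }
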
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index b1732fcde21f1..4e944114e8176 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -18,10 +18,15 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.util.UUID;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
+import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.SettableFuture;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
@@ -113,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeFetchRequest(streamChunkId);
- callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
}
}
});
@@ -129,7 +138,7 @@ public void sendRpc(byte[] message, final RpcResponseCallback callback) {
final long startTime = System.currentTimeMillis();
logger.trace("Sending RPC to {}", serverAddr);
- final long requestId = UUID.randomUUID().getLeastSignificantBits();
+ final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
handler.addRpcRequest(requestId, callback);
channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
@@ -144,16 +153,56 @@ public void operationComplete(ChannelFuture future) throws Exception {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeRpcRequest(requestId);
- callback.onFailure(new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
}
}
});
}
+ /**
+ * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
+ * a specified timeout for a response.
+ */
+ public byte[] sendRpcSync(byte[] message, long timeoutMs) {
+ final SettableFuture<byte[]> result = SettableFuture.create();
+
+ sendRpc(message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ result.set(response);
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ result.setException(e);
+ }
+ });
+
+ try {
+ return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (ExecutionException e) {
+ throw Throwables.propagate(e.getCause());
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
@Override
public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe
channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("remoteAddress", channel.remoteAddress())
+ .add("isActive", isActive())
+ .toString();
+ }
}
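
sendRpcSync is a thin blocking wrapper over the asynchronous sendRpc: it parks on a SettableFuture and rethrows the unwrapped cause on failure. A hedged call-site sketch; the payload and timeout are illustrative, and the client is assumed to come from a TransportClientFactory:

    import com.google.common.base.Charsets;
    import org.apache.spark.network.client.TransportClient;

    final class SyncRpcExample {
      // Blocks until the response arrives or five seconds elapse.
      static byte[] ping(TransportClient client) {
        return client.sendRpcSync("ping".getBytes(Charsets.UTF_8), 5000);
      }
    }
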
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
new file mode 100644
index 0000000000000..65e8020e34121
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user.
+ * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+ /** Performs the bootstrapping operation, throwing an exception on failure. */
+ public void doBootstrap(TransportClient client) throws RuntimeException;
+}
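
A hedged sketch of an implementation, in the spirit of the SASL use case the javadoc mentions: one synchronous handshake RPC per fresh connection, failing client creation if the server rejects it. The token bytes and timeout are illustrative:

    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;

    public class HandshakeBootstrap implements TransportClientBootstrap {
      @Override
      public void doBootstrap(TransportClient client) {
        // One blocking round trip per new connection.
        byte[] reply = client.sendRpcSync(new byte[] { 0x1 }, 10000);
        if (reply == null || reply.length == 0) {
          throw new RuntimeException("Handshake rejected; aborting client creation");
        }
      }
    }
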
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 10eb9ef7a025f..397d3a8455c86 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -18,13 +18,17 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
+import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
@@ -47,22 +51,29 @@
* Factory for creating {@link TransportClient}s by using createClient.
*
* The factory maintains a connection pool to other hosts and should return the same
- * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
- * all {@link TransportClient}s.
+ * TransportClient for the same remote host. It also shares a single worker thread pool for
+ * all TransportClients.
+ *
+ * TransportClients will be reused whenever possible. Prior to completing the creation of a new
+ * TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
private final TransportContext context;
private final TransportConf conf;
+ private final List<TransportClientBootstrap> clientBootstraps;
private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
private final Class<? extends Channel> socketChannelClass;
- private final EventLoopGroup workerGroup;
+ private EventLoopGroup workerGroup;
- public TransportClientFactory(TransportContext context) {
- this.context = context;
+ public TransportClientFactory(
+ TransportContext context,
+ List<TransportClientBootstrap> clientBootstraps) {
+ this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
+ this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
IOMode ioMode = IOMode.valueOf(conf.ioMode());
@@ -72,21 +83,28 @@ public TransportClientFactory(TransportContext context) {
}
/**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
+ * Create a new {@link TransportClient} connecting to the given remote host / port. This will
+ * reuse TransportClients if they are still active and are for the same remote address. Prior
+ * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
+ * that are registered with this factory.
*
- * This blocks until a connection is successfully established.
+ * This blocks until a connection is successfully established and fully bootstrapped.
*
* Concurrency: This method is safe to call from multiple threads.
*/
- public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
- if (cachedClient != null && cachedClient.isActive()) {
- return cachedClient;
- } else if (cachedClient != null) {
- connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ } else {
+ logger.info("Found inactive connection to {}, closing it.", address);
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
}
logger.debug("Creating new connection to " + address);
@@ -102,33 +120,55 @@ public TransportClient createClient(String remoteHost, int remotePort) throws Ti
// Use pooled buffers to reduce temporary buffer allocation
bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
- final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+ final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
- client.set(clientHandler.getClient());
+ clientRef.set(clientHandler.getClient());
}
});
// Connect to the remote server
+ long preConnect = System.currentTimeMillis();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
- throw new TimeoutException(
+ throw new IOException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
- throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
+ }
+
+ TransportClient client = clientRef.get();
+ assert client != null : "Channel future completed successfully with null client";
+
+ // Execute any client bootstraps synchronously before marking the Client as successful.
+ long preBootstrap = System.currentTimeMillis();
+ logger.debug("Connection to {} successful, running bootstraps...", address);
+ try {
+ for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+ clientBootstrap.doBootstrap(client);
+ }
+ } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
+ long bootstrapTime = System.currentTimeMillis() - preBootstrap;
+ logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
+ client.close();
+ throw Throwables.propagate(e);
}
+ long postBootstrap = System.currentTimeMillis();
- // Successful connection
- assert client.get() != null : "Channel future completed successfully with null client";
- TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ // Successful connection & bootstrap -- in the event that two threads raced to create a client,
+ // use the first one that was put into the connectionPool and close the one we made here.
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client);
if (oldClient == null) {
- return client.get();
+ logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ address, postBootstrap - preConnect, postBootstrap - preBootstrap);
+ return client;
} else {
- logger.debug("Two clients were created concurrently, second one will be disposed.");
- client.get().close();
+ logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
+ postBootstrap - preConnect);
+ client.close();
return oldClient;
}
}
@@ -147,6 +187,7 @@ public void close() {
if (workerGroup != null) {
workerGroup.shutdownGracefully();
+ workerGroup = null;
}
}
@@ -158,7 +199,7 @@ public void close() {
*/
private PooledByteBufAllocator createPooledByteBufAllocator() {
return new PooledByteBufAllocator(
- PlatformDependent.directBufferPreferred(),
+ conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
getPrivateStaticField("DEFAULT_PAGE_SIZE"),
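
The concurrency handling in createClient deserves a note: two threads can race past the pool lookup, each build and bootstrap a client, and putIfAbsent arbitrates; the loser closes its own client and adopts the winner's. A generic sketch of that idiom, detached from the networking types:

    import java.util.concurrent.ConcurrentHashMap;

    final class PoolIdiom<K, V extends AutoCloseable> {
      private final ConcurrentHashMap<K, V> pool = new ConcurrentHashMap<K, V>();

      // Returns the canonical value for key; a racing loser closes its own copy.
      V intern(K key, V fresh) throws Exception {
        V prior = pool.putIfAbsent(key, fresh);
        if (prior == null) {
          return fresh;    // we won the race; our value is now shared
        }
        fresh.close();     // we lost; discard ours and reuse the winner's
        return prior;
      }
    }
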
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index d8965590b34da..2044afb0d85db 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.client;
+import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@@ -94,7 +95,7 @@ public void channelUnregistered() {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
logger.error("Still have {} requests outstanding when connection from {} is closed",
numOutstandingRequests(), remoteAddress);
- failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index 152af98ced7ce..986957c1509fd 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
@Override
public int encodedLength() {
- return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString);
}
@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
- byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
- buf.writeInt(errorBytes.length);
- buf.writeBytes(errorBytes);
+ Encoders.Strings.encode(buf, errorString);
}
public static ChunkFetchFailure decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
- int numErrorStringBytes = buf.readInt();
- byte[] errorBytes = new byte[numErrorStringBytes];
- buf.readBytes(errorBytes);
- return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ String errorString = Encoders.Strings.decode(buf);
+ return new ChunkFetchFailure(streamChunkId, errorString);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
new file mode 100644
index 0000000000000..873c694250942
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/** Provides a canonical set of Encoders for simple types. */
+public class Encoders {
+
+ /** Strings are encoded with their length followed by UTF-8 bytes. */
+ public static class Strings {
+ public static int encodedLength(String s) {
+ return 4 + s.getBytes(Charsets.UTF_8).length;
+ }
+
+ public static void encode(ByteBuf buf, String s) {
+ byte[] bytes = s.getBytes(Charsets.UTF_8);
+ buf.writeInt(bytes.length);
+ buf.writeBytes(bytes);
+ }
+
+ public static String decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return new String(bytes, Charsets.UTF_8);
+ }
+ }
+
+ /** Byte arrays are encoded with their length followed by bytes. */
+ public static class ByteArrays {
+ public static int encodedLength(byte[] arr) {
+ return 4 + arr.length;
+ }
+
+ public static void encode(ByteBuf buf, byte[] arr) {
+ buf.writeInt(arr.length);
+ buf.writeBytes(arr);
+ }
+
+ public static byte[] decode(ByteBuf buf) {
+ int length = buf.readInt();
+ byte[] bytes = new byte[length];
+ buf.readBytes(bytes);
+ return bytes;
+ }
+ }
+
+ /** String arrays are encoded with the number of strings followed by per-String encoding. */
+ public static class StringArrays {
+ public static int encodedLength(String[] strings) {
+ int totalLength = 4;
+ for (String s : strings) {
+ totalLength += Strings.encodedLength(s);
+ }
+ return totalLength;
+ }
+
+ public static void encode(ByteBuf buf, String[] strings) {
+ buf.writeInt(strings.length);
+ for (String s : strings) {
+ Strings.encode(buf, s);
+ }
+ }
+
+ public static String[] decode(ByteBuf buf) {
+ int numStrings = buf.readInt();
+ String[] strings = new String[numStrings];
+ for (int i = 0; i < strings.length; i++) {
+ strings[i] = Strings.decode(buf);
+ }
+ return strings;
+ }
+ }
+}
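
A round-trip sketch for the new Encoders helpers; sizing the buffer with encodedLength up front means the writes never trigger a reallocation. The values are illustrative:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import org.apache.spark.network.protocol.Encoders;

    public class EncodersRoundTrip {
      public static void main(String[] args) {
        String[] values = { "shuffle", "block-1" };
        ByteBuf buf = Unpooled.buffer(Encoders.StringArrays.encodedLength(values));
        Encoders.StringArrays.encode(buf, values);
        String[] decoded = Encoders.StringArrays.decode(buf);
        System.out.println(decoded.length + " strings decoded");  // prints: 2 strings decoded
      }
    }
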
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 4cb8becc3ed22..91d1e8a538a77 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {