diff --git a/.gitignore b/.gitignore
index 3d178992123da..857e9feb953bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@
sbt/*.jar
.settings
.cache
+.mima-excludes
/build/
work/
out/
@@ -17,6 +18,7 @@ conf/java-opts
conf/spark-env.sh
conf/streaming-env.sh
conf/log4j.properties
+conf/spark-defaults.conf
docs/_site
docs/api
target/
@@ -45,3 +47,5 @@ dist/
spark-*-bin.tar.gz
unit-tests.log
/lib/
+rat-results.txt
+scalastyle.txt
diff --git a/.rat-excludes b/.rat-excludes
new file mode 100644
index 0000000000000..50766954ef070
--- /dev/null
+++ b/.rat-excludes
@@ -0,0 +1,45 @@
+target
+.gitignore
+.project
+.classpath
+.mima-excludes
+.rat-excludes
+.*md
+derby.log
+TAGS
+RELEASE
+control
+docs
+fairscheduler.xml.template
+spark-defaults.conf.template
+log4j.properties
+log4j.properties.template
+metrics.properties.template
+slaves
+spark-env.sh
+spark-env.sh.template
+log4j-defaults.properties
+sorttable.js
+.*txt
+.*data
+.*log
+cloudpickle.py
+join.py
+SparkExprTyper.scala
+SparkILoop.scala
+SparkILoopInit.scala
+SparkIMain.scala
+SparkImports.scala
+SparkJLineCompletion.scala
+SparkJLineReader.scala
+SparkMemberHandlers.scala
+sbt
+sbt-launch-lib.bash
+plugins.sbt
+work
+.*\.q
+golden
+test.out/*
+.*iml
+service.properties
+db.lck
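
An exclude file like this is consumed by Apache RAT's `-E` flag. A minimal sketch of a license-audit run (the RAT jar path and version are illustrative, not part of this patch); this is also why `rat-results.txt` joins `.gitignore` above:

```bash
# Audit the tree for missing license headers, skipping .rat-excludes patterns.
java -jar /path/to/apache-rat-0.10.jar -E .rat-excludes -d . > rat-results.txt

# RAT flags files with unapproved licenses using "??"; a clean run prints nothing.
grep -e "??" rat-results.txt
```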
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000000000..8ebd0d68429fc
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ language: scala
+ scala:
+ - "2.10.3"
+ jdk:
+ - oraclejdk7
+ env:
+ matrix:
+ - TEST="scalastyle assembly/assembly"
+ - TEST="catalyst/test sql/test streaming/test mllib/test graphx/test bagel/test"
+ - TEST=hive/test
+ cache:
+ directories:
+ - $HOME/.m2
+ - $HOME/.ivy2
+ - $HOME/.sbt
+ script:
+ - "sbt ++$TRAVIS_SCALA_VERSION $TEST"
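
Each matrix entry expands into a single sbt invocation, so a failing Travis job can be reproduced locally. A sketch, assuming the bundled sbt launcher:

```bash
# Re-run one Travis matrix entry by hand.
TRAVIS_SCALA_VERSION="2.10.3"
TEST="catalyst/test sql/test streaming/test mllib/test graphx/test bagel/test"
sbt/sbt ++$TRAVIS_SCALA_VERSION $TEST
```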
diff --git a/LICENSE b/LICENSE
index 1c166d1333614..1c1c2c0255fa9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -396,3 +396,35 @@ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
+
+
+========================================================================
+For sbt and sbt-launch-lib.bash in sbt/:
+========================================================================
+
+// Generated from http://www.opensource.org/licenses/bsd-license.php
+Copyright (c) 2011, Paul Phillips.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+ * Neither the name of the author nor the names of its contributors may be
+ used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/NOTICE b/NOTICE
index 7cbb114b2ae2d..42f6c3a835725 100644
--- a/NOTICE
+++ b/NOTICE
@@ -1,5 +1,14 @@
Apache Spark
-Copyright 2013 The Apache Software Foundation.
+Copyright 2014 The Apache Software Foundation.
This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).
+
+In addition, this product includes:
+
+- JUnit (http://www.junit.org) is a testing framework for Java. We included it
+ under the terms of the Eclipse Public License v1.0.
+
+- JTransforms (https://sites.google.com/site/piotrwendykier/software/jtransforms)
+ provides fast transforms in Java. It is tri-licensed, and we included it under
+ the terms of the Mozilla Public License v1.1.
diff --git a/README.md b/README.md
index c840a68f76b17..e2d1dcb5672ff 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,42 @@
# Apache Spark
-Lightning-Fast Cluster Computing - <http://spark.incubator.apache.org>
+Lightning-Fast Cluster Computing - <http://spark.apache.org>
## Online Documentation
You can find the latest Spark documentation, including a programming
-guide, on the project webpage at <http://spark.incubator.apache.org>.
+guide, on the project webpage at <http://spark.apache.org>.
This README file only contains basic setup instructions.
-## Building
+## Building Spark
-Spark requires Scala 2.10. The project is built using Simple Build Tool (SBT),
-which can be obtained [here](http://www.scala-sbt.org). If SBT is installed we
-will use the system version of sbt otherwise we will attempt to download it
-automatically. To build Spark and its example programs, run:
+Spark is built on Scala 2.10. To build Spark and its example programs, run:
./sbt/sbt assembly
-Once you've built Spark, the easiest way to start using it is the shell:
+## Interactive Scala Shell
+
+The easiest way to start using Spark is through the Scala shell:
./bin/spark-shell
-Or, for the Python API, the Python shell (`./bin/pyspark`).
+Try the following command, which should return 1000:
+
+ scala> sc.parallelize(1 to 1000).count()
+
+## Interactive Python Shell
+
+Alternatively, if you prefer Python, you can use the Python shell:
+
+ ./bin/pyspark
+
+And run the following command, which should also return 1000:
+
+ >>> sc.parallelize(range(1000)).count()
+
+## Example Programs
Spark also comes with several sample programs in the `examples` directory.
To run one of them, use `./bin/run-example <class> <params>`. For example:
@@ -38,13 +51,13 @@ All of the Spark samples take a `<master>` parameter that is the cluster URL
to connect to. This can be a mesos:// or spark:// URL, or "local" to run
locally with one thread, or "local[N]" to run locally with N threads.
-## Running tests
+## Running Tests
-Testing first requires [Building](#building) Spark. Once Spark is built, tests
+Testing first requires [building Spark](#building-spark). Once Spark is built, tests
can be run using:
-`./sbt/sbt test`
-
+ ./sbt/sbt test
+
## A Note About Hadoop Versions
Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported
@@ -92,21 +105,10 @@ If your project is built with Maven, add this to your POM file's `<dependencies>` section:
## Configuration
-Please refer to the [Configuration guide](http://spark.incubator.apache.org/docs/latest/configuration.html)
+Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html)
in the online documentation for an overview on how to configure Spark.
-## Apache Incubator Notice
-
-Apache Spark is an effort undergoing incubation at The Apache Software
-Foundation (ASF), sponsored by the Apache Incubator. Incubation is required of
-all newly accepted projects until a further review indicates that the
-infrastructure, communications, and decision making process have stabilized in
-a manner consistent with other successful ASF projects. While incubation status
-is not necessarily a reflection of the completeness or stability of the code,
-it does indicate that the project has yet to be fully endorsed by the ASF.
-
-
## Contributing to Spark
Contributions via GitHub pull requests are gladly accepted from their original
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 82396040251d3..bdb38806492a6 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,17 +21,20 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.0.0-incubating-SNAPSHOT</version>
+    <version>1.0.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

   <groupId>org.apache.spark</groupId>
   <artifactId>spark-assembly_2.10</artifactId>
   <name>Spark Project Assembly</name>
-  <url>http://spark.incubator.apache.org/</url>
+  <url>http://spark.apache.org/</url>
+  <packaging>pom</packaging>

   <properties>
-    <spark.jar>${project.build.directory}/scala-${scala.binary.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar</spark.jar>
+    <spark.jar.dir>scala-${scala.binary.version}</spark.jar.dir>
+    <spark.jar.basename>spark-assembly-${project.version}-hadoop${hadoop.version}.jar</spark.jar.basename>
+    <spark.jar>${project.build.directory}/${spark.jar.dir}/${spark.jar.basename}</spark.jar>
     <deb.pkg.name>spark</deb.pkg.name>
     <deb.install.path>/usr/share/spark</deb.install.path>
     <deb.user>root</deb.user>
@@ -76,6 +79,11 @@
       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>net.sf.py4j</groupId>
       <artifactId>py4j</artifactId>
@@ -155,6 +163,26 @@
   <profiles>
+    <profile>
+      <id>hive</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-hive_${scala.binary.version}</artifactId>
+          <version>${project.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
+    <profile>
+      <id>spark-ganglia-lgpl</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-ganglia-lgpl_${scala.binary.version}</artifactId>
+          <version>${project.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
+    <profile>
+      <id>yarn-alpha</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.avro</groupId>
+          <artifactId>avro</artifactId>
+        </dependency>
+      </dependencies>
+    </profile>
     <profile>
       <id>bigtop-dist</id>
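
The new `hive` and `spark-ganglia-lgpl` profiles are opt-in, so the extra (and, for Ganglia, LGPL-licensed) modules only reach the assembly when a profile id is activated explicitly. For example, with Maven:

```bash
# Build an assembly that bundles the Hive support module.
mvn -Phive -DskipTests clean package

# Build one that also includes the LGPL Ganglia metrics sink.
mvn -Phive -Pspark-ganglia-lgpl -DskipTests clean package
```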
diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala
index dd3eed8affe39..70a99b33d753c 100644
--- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala
+++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala
@@ -27,7 +27,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program.
- * @param sc [[org.apache.spark.SparkContext]] to use for the program.
+ * @param sc org.apache.spark.SparkContext to use for the program.
* @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the
* Key will be the vertex id.
* @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often
@@ -38,10 +38,10 @@ object Bagel extends Logging {
* @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices
* after each superstep and provides the result to each vertex in the next
* superstep.
- * @param partitioner [[org.apache.spark.Partitioner]] partitions values by key
+ * @param partitioner org.apache.spark.Partitioner partitions values by key
* @param numPartitions number of partitions across which to split the graph.
* Default is the default parallelism of the SparkContext
- * @param storageLevel [[org.apache.spark.storage.StorageLevel]] to use for caching of
+ * @param storageLevel org.apache.spark.storage.StorageLevel to use for caching of
* intermediate RDDs in each superstep. Defaults to caching in memory.
* @param compute function that takes a Vertex, optional set of (possibly combined) messages to
* the Vertex, optional Aggregator and the current superstep,
@@ -131,7 +131,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default
- * [[org.apache.spark.HashPartitioner]] and default storage level
+ * org.apache.spark.HashPartitioner and default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
@@ -146,7 +146,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the
- * default [[org.apache.spark.HashPartitioner]]
+ * default org.apache.spark.HashPartitioner
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
@@ -166,7 +166,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]],
- * default [[org.apache.spark.HashPartitioner]],
+ * default org.apache.spark.HashPartitioner,
* [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
@@ -180,7 +180,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]],
- * the default [[org.apache.spark.HashPartitioner]]
+ * the default org.apache.spark.HashPartitioner
* and [[org.apache.spark.bagel.DefaultCombiner]]
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
@@ -220,20 +220,23 @@ object Bagel extends Logging {
*/
private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
sc: SparkContext,
- grouped: RDD[(K, (Seq[C], Seq[V]))],
+ grouped: RDD[(K, (Iterable[C], Iterable[V]))],
compute: (V, Option[C]) => (V, Array[M]),
storageLevel: StorageLevel
): (RDD[(K, (V, Array[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
- val processed = grouped.flatMapValues {
- case (_, vs) if vs.size == 0 => None
- case (c, vs) =>
+ val processed = grouped.mapValues(x => (x._1.iterator, x._2.iterator))
+ .flatMapValues {
+ case (_, vs) if !vs.hasNext => None
+ case (c, vs) => {
val (newVert, newMsgs) =
- compute(vs(0), c match {
- case Seq(comb) => Some(comb)
- case Seq() => None
- })
+ compute(vs.next,
+ c.hasNext match {
+ case true => Some(c.next)
+ case false => None
+ }
+ )
numMsgs += newMsgs.size
if (newVert.active) {
@@ -241,6 +244,7 @@ object Bagel extends Logging {
}
Some((newVert, newMsgs))
+ }
}.persist(storageLevel)
// Force evaluation of processed RDD for accurate performance measurements
diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index 9c37fadb78d2f..8e0f82ddb8897 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -24,13 +24,15 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.storage.StorageLevel
+import scala.language.postfixOps
+
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
-
+
var sc: SparkContext = _
-
+
after {
if (sc != null) {
sc.stop()
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index 4f60bff19cb93..065553eb31939 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -1,69 +1,88 @@
-@echo off
-
-rem
-rem Licensed to the Apache Software Foundation (ASF) under one or more
-rem contributor license agreements. See the NOTICE file distributed with
-rem this work for additional information regarding copyright ownership.
-rem The ASF licenses this file to You under the Apache License, Version 2.0
-rem (the "License"); you may not use this file except in compliance with
-rem the License. You may obtain a copy of the License at
-rem
-rem http://www.apache.org/licenses/LICENSE-2.0
-rem
-rem Unless required by applicable law or agreed to in writing, software
-rem distributed under the License is distributed on an "AS IS" BASIS,
-rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-rem See the License for the specific language governing permissions and
-rem limitations under the License.
-rem
-
-rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
-rem script and the ExecutorRunner in standalone cluster mode.
-
-set SCALA_VERSION=2.10
-
-rem Figure out where the Spark framework is installed
-set FWDIR=%~dp0..\
-
-rem Load environment variables from conf\spark-env.cmd, if it exists
-if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
-
-rem Build up classpath
-set CLASSPATH=%FWDIR%conf
-if exist "%FWDIR%RELEASE" (
- for %%d in ("%FWDIR%jars\spark-assembly*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-) else (
- for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-)
-set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
-
-if "x%SPARK_TESTING%"=="x1" (
- rem Add test clases to path
- set CLASSPATH=%CLASSPATH%;%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
- set CLASSPATH=%CLASSPATH%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
- set CLASSPATH=%CLASSPATH%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
- set CLASSPATH=%CLASSPATH%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
- set CLASSPATH=%CLASSPATH%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
-)
-
-rem Add hadoop conf dir - else FileSystem.*, etc fail
-rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
-rem the configurtion files.
-if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
- set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
-:no_hadoop_conf_dir
-
-if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
- set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
-:no_yarn_conf_dir
-
-rem A bit of a hack to allow calling this script within run2.cmd without seeing output
-if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
-
-echo %CLASSPATH%
-
-:exit
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+rem script and the ExecutorRunner in standalone cluster mode.
+
+set SCALA_VERSION=2.10
+
+rem Figure out where the Spark framework is installed
+set FWDIR=%~dp0..\
+
+rem Load environment variables from conf\spark-env.cmd, if it exists
+if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+
+rem Build up classpath
+set CLASSPATH=%FWDIR%conf
+if exist "%FWDIR%RELEASE" (
+ for %%d in ("%FWDIR%jars\spark-assembly*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+) else (
+ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+)
+
+set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
+
+set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
+
+set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
+
+if "x%SPARK_TESTING%"=="x1" (
+ rem Add test classes to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
+ rem so that local compilation takes precedence over assembled jar
+ set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
+)
+
+rem Add hadoop conf dir - else FileSystem.*, etc fail
+rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
+rem the configuration files.
+if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
+ set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
+:no_hadoop_conf_dir
+
+if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
+ set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
+:no_yarn_conf_dir
+
+rem A bit of a hack to allow calling this script within run2.cmd without seeing output
+if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
+
+echo %CLASSPATH%
+
+:exit
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 278969655de48..3a59f599fd7d2 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -25,35 +25,55 @@ SCALA_VERSION=2.10
# Figure out where Spark is installed
FWDIR="$(cd `dirname $0`/..; pwd)"
-# Load environment variables from conf/spark-env.sh, if it exists
-if [ -e "$FWDIR/conf/spark-env.sh" ] ; then
- . $FWDIR/conf/spark-env.sh
-fi
+. $FWDIR/bin/load-spark-env.sh
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
+ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION"
+
# First check if we have a dependencies jar. If so, include binary classes with the deps jar
-if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
+if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then
CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes"
- DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
+ DEPS_ASSEMBLY_JAR=`ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar`
CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
# Else use spark-assembly jar from either RELEASE or assembly directory
if [ -f "$FWDIR/RELEASE" ]; then
- ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+ ASSEMBLY_JAR=`ls "$FWDIR"/lib/spark-assembly*hadoop*.jar`
else
- ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ ASSEMBLY_JAR=`ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*.jar`
fi
CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
+# When Hive support is needed, Datanucleus jars must be included on the classpath.
+# Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
+# Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
+# built with Hive, so first check if the datanucleus jars exist, and then ensure the current Spark
+# assembly is built for Hive, before actually populating the CLASSPATH with the jars.
+# Note that this check order is faster (by up to half a second) in the case where Hive is not used.
+num_datanucleus_jars=$(ls "$FWDIR"/lib_managed/jars/ 2>/dev/null | grep "datanucleus-.*\\.jar" | wc -l)
+if [ $num_datanucleus_jars -gt 0 ]; then
+ AN_ASSEMBLY_JAR=${ASSEMBLY_JAR:-$DEPS_ASSEMBLY_JAR}
+ num_hive_files=$(jar tvf "$AN_ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null | wc -l)
+ if [ $num_hive_files -gt 0 ]; then
+ echo "Spark assembly has been built with Hive, including Datanucleus jars on classpath" 1>&2
+ DATANUCLEUSJARS=$(echo "$FWDIR/lib_managed/jars"/datanucleus-*.jar | tr " " :)
+ CLASSPATH=$CLASSPATH:$DATANUCLEUSJARS
+ fi
+fi
+
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes"
@@ -62,6 +82,9 @@ if [[ $SPARK_TESTING == 1 ]]; then
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes"
fi
# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail !
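
Since the script prints one colon-joined string on stdout (that is how spark-class consumes it), a quick way to sanity-check what the new sql/ and tools/ entries resolve to:

```bash
# Show each classpath entry on its own line.
./bin/compute-classpath.sh | tr ':' '\n'
```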
diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh
new file mode 100644
index 0000000000000..d425f9feaac54
--- /dev/null
+++ b/bin/load-spark-env.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This script loads spark-env.sh if it exists, and ensures it is only loaded once.
+# spark-env.sh is loaded from SPARK_CONF_DIR if set, otherwise from the conf/
+# subdirectory of the Spark installation.
+
+if [ -z "$SPARK_ENV_LOADED" ]; then
+ export SPARK_ENV_LOADED=1
+
+ # Returns the parent of the directory this script lives in.
+ parent_dir="$(cd `dirname $0`/..; pwd)"
+
+ use_conf_dir=${SPARK_CONF_DIR:-"$parent_dir/conf"}
+
+ if [ -f "${use_conf_dir}/spark-env.sh" ]; then
+ # Promote all variable declarations to environment (exported) variables
+ set -a
+ . "${use_conf_dir}/spark-env.sh"
+ set +a
+ fi
+fi
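
The guard means spark-env.sh is sourced at most once per process tree, however many Spark scripts chain into each other. A usage sketch (the /etc/spark path is illustrative):

```bash
# Point Spark at an external conf directory; its spark-env.sh is picked up.
SPARK_CONF_DIR=/etc/spark ./bin/spark-shell

# Scripts launched from there skip re-sourcing, because the first load
# exported SPARK_ENV_LOADED=1 (and, via `set -a`, everything it defined).
```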
diff --git a/bin/pyspark b/bin/pyspark
index ed6f8da73035a..cad982bc33477 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -36,10 +36,7 @@ if [ ! -f "$FWDIR/RELEASE" ]; then
fi
fi
-# Load environment variables from conf/spark-env.sh, if it exists
-if [ -e "$FWDIR/conf/spark-env.sh" ] ; then
- . $FWDIR/conf/spark-env.sh
-fi
+. $FWDIR/bin/load-spark-env.sh
# Figure out which Python executable to use
if [ -z "$PYSPARK_PYTHON" ] ; then
@@ -58,7 +55,8 @@ if [ -n "$IPYTHON_OPTS" ]; then
IPYTHON=1
fi
-if [[ "$IPYTHON" = "1" ]] ; then
+# Only use ipython if no command line arguments were provided [SPARK-1134]
+if [[ "$IPYTHON" = "1" && $# = 0 ]] ; then
exec ipython $IPYTHON_OPTS
else
exec "$PYSPARK_PYTHON" "$@"
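
The `$# = 0` guard is the SPARK-1134 fix: IPython swallowed script arguments, so it is now used only for a bare interactive session. An illustration (the script name is hypothetical):

```bash
# Interactive session: still gets IPython when requested.
IPYTHON=1 ./bin/pyspark

# With arguments, the plain interpreter runs so argv reaches the script.
IPYTHON=1 ./bin/pyspark my_script.py arg1
```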
diff --git a/bin/run-example b/bin/run-example
index adba7dd97aaf8..d8a94f2e31e07 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -30,10 +30,7 @@ FWDIR="$(cd `dirname $0`/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
-# Load environment variables from conf/spark-env.sh, if it exists
-if [ -e "$FWDIR/conf/spark-env.sh" ] ; then
- . $FWDIR/conf/spark-env.sh
-fi
+. $FWDIR/bin/load-spark-env.sh
if [ -z "$1" ]; then
echo "Usage: run-example []" >&2
@@ -43,12 +40,15 @@ fi
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
# to avoid the -sources and -doc packages that are built by publish-local.
EXAMPLES_DIR="$FWDIR"/examples
-SPARK_EXAMPLES_JAR=""
-if [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then
- export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar`
+
+if [ -f "$FWDIR/RELEASE" ]; then
+ export SPARK_EXAMPLES_JAR=`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`
+elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`
fi
+
if [[ -z $SPARK_EXAMPLES_JAR ]]; then
- echo "Failed to find Spark examples assembly in $FWDIR/examples/target" >&2
+ echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" >&2
echo "You need to build Spark with sbt/sbt assembly before running this program" >&2
exit 1
fi
@@ -78,7 +78,6 @@ fi
# Set JAVA_OPTS to be able to load native libraries and to set heap size
JAVA_OPTS="$SPARK_JAVA_OPTS"
-JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
if [ -e "$FWDIR/conf/java-opts" ] ; then
JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`"
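
With the release layout's lib/ directory now searched first, an example run looks the same either way; the class below is one of the standard bundled examples:

```bash
# Run the bundled SparkPi example locally with two threads.
./bin/run-example org.apache.spark.examples.SparkPi local[2]
```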
diff --git a/bin/spark-class b/bin/spark-class
index c4225a392d6da..6871e180c9fa8 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -30,44 +30,57 @@ FWDIR="$(cd `dirname $0`/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
-# Load environment variables from conf/spark-env.sh, if it exists
-if [ -e "$FWDIR/conf/spark-env.sh" ] ; then
- . $FWDIR/conf/spark-env.sh
-fi
+. $FWDIR/bin/load-spark-env.sh
if [ -z "$1" ]; then
echo "Usage: spark-class []" >&2
exit 1
fi
-# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable
-# values for that; it doesn't need a lot
-if [ "$1" = "org.apache.spark.deploy.master.Master" -o "$1" = "org.apache.spark.deploy.worker.Worker" ]; then
- SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
- SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
- # Do not overwrite SPARK_JAVA_OPTS environment variable in this script
- OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default
-else
- OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
+if [ -n "$SPARK_MEM" ]; then
+ echo "Warning: SPARK_MEM is deprecated, please use a more specific config option"
+ echo "(e.g., spark.executor.memory or SPARK_DRIVER_MEMORY)."
fi
+# Use SPARK_MEM or 512m as the default memory, to be overridden by specific options
+DEFAULT_MEM=${SPARK_MEM:-512m}
+
+SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
-# Add java opts for master, worker, executor. The opts maybe null
+# Add java opts and memory settings for master, worker, history server, executors, and repl.
case "$1" in
+ # Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY.
'org.apache.spark.deploy.master.Master')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS"
+ OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_MASTER_OPTS"
+ OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM}
;;
'org.apache.spark.deploy.worker.Worker')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
+ OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_WORKER_OPTS"
+ OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM}
;;
+ 'org.apache.spark.deploy.history.HistoryServer')
+ OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_HISTORY_OPTS"
+ OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM}
+ ;;
+
+ # Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY.
'org.apache.spark.executor.CoarseGrainedExecutorBackend')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM}
;;
'org.apache.spark.executor.MesosExecutorBackend')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM}
;;
+
+ # All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS.
'org.apache.spark.repl.Main')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS"
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS"
+ OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM}
+ ;;
+ *)
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
+ OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM}
;;
esac
@@ -83,14 +96,10 @@ else
fi
fi
-# Set SPARK_MEM if it isn't already set since we also use it for this process
-SPARK_MEM=${SPARK_MEM:-512m}
-export SPARK_MEM
-
# Set JAVA_OPTS to be able to load native libraries and to set heap size
JAVA_OPTS="$OUR_JAVA_OPTS"
-JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH"
-JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM"
+JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$_SPARK_LIBRARY_PATH"
+JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
if [ -e "$FWDIR/conf/java-opts" ] ; then
JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`"
@@ -129,8 +138,7 @@ fi
# Compute classpath using external script
CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
-
-if [ "$1" == "org.apache.spark.tools.JavaAPICompletenessChecker" ]; then
+if [[ "$1" =~ org.apache.spark.tools.* ]]; then
CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
fi
@@ -150,5 +158,3 @@ if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then
fi
exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@"
-
-
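
The net effect is that SPARK_MEM's single knob is replaced by per-role variables, each falling back to SPARK_MEM and then 512m. A sketch of the new precedence:

```bash
# Daemons (master/worker/history server) size their heap from SPARK_DAEMON_MEMORY,
# executor backends from SPARK_EXECUTOR_MEMORY, and drivers (including the repl)
# from SPARK_DRIVER_MEMORY. For example:
SPARK_DRIVER_MEMORY=2g ./bin/spark-class org.apache.spark.repl.Main
```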
diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd
index 80818c78ec24b..4302c1b6b7ff4 100755
--- a/bin/spark-class2.cmd
+++ b/bin/spark-class2.cmd
@@ -34,22 +34,48 @@ if not "x%1"=="x" goto arg_given
goto exit
:arg_given
-set RUNNING_DAEMON=0
-if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1
-if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
-if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
+if not "x%SPARK_MEM%"=="x" (
+ echo Warning: SPARK_MEM is deprecated, please use a more specific config option
+ echo e.g., spark.executor.memory or SPARK_DRIVER_MEMORY.
+)
+
+rem Use SPARK_MEM or 512m as the default memory, to be overridden by specific options
+set OUR_JAVA_MEM=%SPARK_MEM%
+if "x%OUR_JAVA_MEM%"=="x" set OUR_JAVA_MEM=512m
+
set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true
-if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
-rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script
-if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
-if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
-rem Figure out how much memory to use per executor and set it as an environment
-rem variable so that our process sees it and can report it to Mesos
-if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m
+rem Add java opts and memory settings for master, worker, history server, executors, and repl.
+rem Master, Worker and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY.
+if "%1"=="org.apache.spark.deploy.master.Master" (
+ set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_MASTER_OPTS%
+ if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY%
+) else if "%1"=="org.apache.spark.deploy.worker.Worker" (
+ set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_WORKER_OPTS%
+ if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY%
+) else if "%1"=="org.apache.spark.deploy.history.HistoryServer" (
+ set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_HISTORY_OPTS%
+ if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY%
+
+rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY.
+) else if "%1"=="org.apache.spark.executor.CoarseGrainedExecutorBackend" (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS%
+ if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY%
+) else if "%1"=="org.apache.spark.executor.MesosExecutorBackend" (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS%
+ if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY%
+
+rem All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS.
+) else if "%1"=="org.apache.spark.repl.Main" (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_REPL_OPTS%
+ if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY%
+) else (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
+ if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY%
+)
rem Set JAVA_OPTS to be able to load native libraries and to set heap size
-set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
+set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM%
rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
rem Test whether the user has built Spark
diff --git a/bin/spark-shell b/bin/spark-shell
index 2bff06cf70051..f1f3c18877ed4 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -19,9 +19,8 @@
#
# Shell script for starting the Spark Shell REPL
-# Note that it will set MASTER to spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}
-# if those two env vars are set in spark-env.sh but MASTER is not.
+args="$@"
cygwin=false
case "`uname`" in
CYGWIN*) cygwin=true;;
@@ -30,71 +29,31 @@ esac
# Enter posix mode for bash
set -o posix
-CORE_PATTERN="^[0-9]+$"
-MEM_PATTERN="^[0-9]+[m|g|M|G]$"
-
-FWDIR="$(cd `dirname $0`/..; pwd)"
-
-if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then
- echo "Usage: spark-shell [OPTIONS]"
- echo "OPTIONS:"
- echo "-c --cores num, the maximum number of cores to be used by the spark shell"
- echo "-em --execmem num[m|g], the memory used by each executor of spark shell"
- echo "-dm --drivermem num[m|g], the memory used by the spark shell and driver"
- echo "-h --help, print this help information"
- exit
+if [[ "$@" == *--help* ]]; then
+ echo "Usage: ./bin/spark-shell [options]"
+ ./bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
+ exit 0
fi
-SPARK_SHELL_OPTS=""
-
-for o in "$@"; do
- if [ "$1" = "-c" -o "$1" = "--cores" ]; then
- shift
- if [[ "$1" =~ $CORE_PATTERN ]]; then
- SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.cores.max=$1"
- shift
- else
- echo "ERROR: wrong format for -c/--cores"
- exit 1
- fi
- fi
- if [ "$1" = "-em" -o "$1" = "--execmem" ]; then
- shift
- if [[ $1 =~ $MEM_PATTERN ]]; then
- SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.executor.memory=$1"
- shift
- else
- echo "ERROR: wrong format for --execmem/-em"
- exit 1
- fi
- fi
- if [ "$1" = "-dm" -o "$1" = "--drivermem" ]; then
- shift
- if [[ $1 =~ $MEM_PATTERN ]]; then
- export SPARK_MEM=$1
- shift
- else
- echo "ERROR: wrong format for --drivermem/-dm"
- exit 1
- fi
- fi
-done
+## Global script variables
+FWDIR="$(cd `dirname $0`/..; pwd)"
-# Set MASTER from spark-env if possible
-DEFAULT_SPARK_MASTER_PORT=7077
-if [ -z "$MASTER" ]; then
- if [ -e "$FWDIR/conf/spark-env.sh" ]; then
- . "$FWDIR/conf/spark-env.sh"
- fi
- if [ "x" != "x$SPARK_MASTER_IP" ]; then
- if [ "y" != "y$SPARK_MASTER_PORT" ]; then
- SPARK_MASTER_PORT="${SPARK_MASTER_PORT}"
+function main(){
+ if $cygwin; then
+ # Workaround for issue involving JLine and Cygwin
+ # (see http://sourceforge.net/p/jline/bugs/40/).
+ # If you're using the Mintty terminal emulator in Cygwin, you may need to set the
+ # "Backspace sends ^H" setting in "Keys" section of the Mintty options
+ # (see https://github.com/sbt/sbt/issues/562).
+ stty -icanon min 1 -echo > /dev/null 2>&1
+ export SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Djline.terminal=unix"
+ $FWDIR/bin/spark-submit spark-internal "$args" --class org.apache.spark.repl.Main
+ stty icanon echo > /dev/null 2>&1
else
- SPARK_MASTER_PORT=$DEFAULT_SPARK_MASTER_PORT
+ export SPARK_REPL_OPTS
+ $FWDIR/bin/spark-submit spark-internal "$args" --class org.apache.spark.repl.Main
fi
- export MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}"
- fi
-fi
+}
# Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in
# binary distribution of Spark where Scala is not installed
@@ -124,20 +83,10 @@ if [[ ! $? ]]; then
saved_stty=""
fi
-if $cygwin; then
- # Workaround for issue involving JLine and Cygwin
- # (see http://sourceforge.net/p/jline/bugs/40/).
- # If you're using the Mintty terminal emulator in Cygwin, may need to set the
- # "Backspace sends ^H" setting in "Keys" section of the Mintty options
- # (see https://github.com/sbt/sbt/issues/562).
- stty -icanon min 1 -echo > /dev/null 2>&1
- $FWDIR/bin/spark-class -Djline.terminal=unix $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@"
- stty icanon echo > /dev/null 2>&1
-else
- $FWDIR/bin/spark-class $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@"
-fi
+main
# record the exit status lest it be overwritten:
# then reenable echo and propagate the code.
exit_status=$?
onExit
+
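
spark-shell is now a thin wrapper around spark-submit, so option handling lives in one place; what remains here is TTY handling and the Cygwin JLine workaround. For example:

```bash
# Usage text is delegated to spark-submit's own help output.
./bin/spark-shell --help

# Extra repl JVM options still flow through SPARK_REPL_OPTS.
SPARK_REPL_OPTS="-XX:MaxPermSize=256m" ./bin/spark-shell
```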
diff --git a/bin/spark-submit b/bin/spark-submit
new file mode 100755
index 0000000000000..b2a1dca721dff
--- /dev/null
+++ b/bin/spark-submit
@@ -0,0 +1,43 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export SPARK_HOME="$(cd `dirname $0`/..; pwd)"
+ORIG_ARGS=$@
+
+while (($#)); do
+ if [ "$1" = "--deploy-mode" ]; then
+ DEPLOY_MODE=$2
+ elif [ "$1" = "--driver-memory" ]; then
+ DRIVER_MEMORY=$2
+ elif [ "$1" = "--driver-library-path" ]; then
+ export _SPARK_LIBRARY_PATH=$2
+ elif [ "$1" = "--driver-class-path" ]; then
+ export SPARK_CLASSPATH="$SPARK_CLASSPATH:$2"
+ elif [ "$1" = "--driver-java-options" ]; then
+ export SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $2"
+ fi
+ shift
+done
+
+if [ ! -z $DRIVER_MEMORY ] && [ ! -z $DEPLOY_MODE ] && [ $DEPLOY_MODE = "client" ]; then
+ export SPARK_MEM=$DRIVER_MEMORY
+fi
+
+$SPARK_HOME/bin/spark-class org.apache.spark.deploy.SparkSubmit $ORIG_ARGS
+
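
The wrapper only interprets the driver-side flags named in the loop, exporting them as environment for spark-class; the untouched original argument list is still forwarded to SparkSubmit. A hypothetical invocation (the jar and class names are illustrative; the app jar leads, as in the spark-shell call above):

```bash
# Driver memory, library path, classpath, and JVM options are peeled off here;
# everything else is interpreted by org.apache.spark.deploy.SparkSubmit itself.
./bin/spark-submit my-app.jar \
  --class org.example.MyApp \
  --deploy-mode client \
  --driver-memory 1g
```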
diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template
new file mode 100644
index 0000000000000..f840ff681d019
--- /dev/null
+++ b/conf/spark-defaults.conf.template
@@ -0,0 +1,7 @@
+# Default system properties included when running spark-submit.
+# This is useful for setting default environment settings.
+
+# Example:
+# spark.master spark://master:7077
+# spark.eventLog.enabled true
+# spark.eventLog.dir hdfs://namenode:8021/directory
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 6432a566089be..f906be611a931 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -1,22 +1,43 @@
#!/usr/bin/env bash
-# This file contains environment variables required to run Spark. Copy it as
-# spark-env.sh and edit that to configure Spark for your site.
-#
-# The following variables can be set in this file:
+# This file is sourced when running various Spark programs.
+# Copy it as spark-env.sh and edit that to configure Spark for your site.
+
+# Options read when launching programs locally with
+# ./bin/run-example or ./bin/spark-submit
+# - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files
+# - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node
+# - SPARK_PUBLIC_DNS, to set the public dns name of the driver program
+# - SPARK_CLASSPATH, default classpath entries to append
+
+# Options read by executors and drivers running inside the cluster
# - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node
+# - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program
+# - SPARK_CLASSPATH, default classpath entries to append
+# - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data
# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos
-# - SPARK_JAVA_OPTS, to set node-specific JVM options for Spark. Note that
-# we recommend setting app-wide options in the application's driver program.
-# Examples of node-specific options : -Dspark.local.dir, GC options
-# Examples of app-wide options : -Dspark.serializer
-#
-# If using the standalone deploy mode, you can also set variables for it here:
+
+# Options read in YARN client mode
+# - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files
+# - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2)
+# - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1).
+# - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G)
+# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)
+# - SPARK_YARN_APP_NAME, The name of your application (Default: Spark)
+# - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’)
+# - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job.
+# - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job.
+
+# Options for the daemons used in the standalone deploy mode:
# - SPARK_MASTER_IP, to bind the master to a different IP address or hostname
# - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports
+# - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. "-Dx=y")
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
-# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
+# - SPARK_WORKER_MEMORY, to set how much total memory workers have to give executors (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node
# - SPARK_WORKER_DIR, to set the working directory of worker processes
-# - SPARK_PUBLIC_DNS, to set the public dns name of the master
+# - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y")
+# - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y")
+# - SPARK_DAEMON_OPTS, to set config properties for all daemons (e.g. "-Dx=y")
+# - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers
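
A minimal conf/spark-env.sh for a standalone cluster, using only variables documented above (the values are illustrative):

```bash
#!/usr/bin/env bash
SPARK_MASTER_IP=192.168.1.10
SPARK_WORKER_CORES=4
SPARK_WORKER_MEMORY=8g
SPARK_DAEMON_MEMORY=1g
```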
diff --git a/core/pom.xml b/core/pom.xml
index 5576b0c3b4795..058b7acba73ca 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -17,242 +17,305 @@
-->
- 4.0.0
-
- org.apache.spark
- spark-parent
- 1.0.0-incubating-SNAPSHOT
- ../pom.xml
-
-
+ 4.0.0
+ org.apache.spark
- spark-core_2.10
- jar
- Spark Project Core
- http://spark.incubator.apache.org/
+ spark-parent
+ 1.0.0-SNAPSHOT
+ ../pom.xml
+
-
-
- org.apache.hadoop
- hadoop-client
-
-
- net.java.dev.jets3t
- jets3t
-
-
- commons-logging
- commons-logging
-
-
-
-
- org.apache.avro
- avro
-
-
- org.apache.avro
- avro-ipc
-
-
- org.apache.zookeeper
- zookeeper
-
-
- org.eclipse.jetty
- jetty-server
-
-
- com.google.guava
- guava
-
-
- com.google.code.findbugs
- jsr305
-
-
- org.slf4j
- slf4j-api
-
-
- org.slf4j
- jul-to-slf4j
-
-
- org.slf4j
- jcl-over-slf4j
-
-
- log4j
- log4j
-
-
- org.slf4j
- slf4j-log4j12
-
-
- com.ning
- compress-lzf
-
-
- org.xerial.snappy
- snappy-java
-
-
- org.ow2.asm
- asm
-
-
- com.twitter
- chill_${scala.binary.version}
- 0.3.1
-
-
- com.twitter
- chill-java
- 0.3.1
-
-
- ${akka.group}
- akka-remote_${scala.binary.version}
-
-
- ${akka.group}
- akka-slf4j_${scala.binary.version}
-
-
- ${akka.group}
- akka-testkit_${scala.binary.version}
- test
-
-
- org.scala-lang
- scala-library
-
-
- net.liftweb
- lift-json_${scala.binary.version}
-
-
- it.unimi.dsi
- fastutil
-
-
- colt
- colt
-
-
- org.apache.mesos
- mesos
-
-
- io.netty
- netty-all
-
-
- com.clearspring.analytics
- stream
-
-
- com.codahale.metrics
- metrics-core
-
-
- com.codahale.metrics
- metrics-jvm
-
-
- com.codahale.metrics
- metrics-json
-
-
- com.codahale.metrics
- metrics-ganglia
-
-
- com.codahale.metrics
- metrics-graphite
-
-
- org.apache.derby
- derby
- test
-
-
- commons-io
- commons-io
- test
-
-
- org.scalatest
- scalatest_${scala.binary.version}
- test
-
-
- org.mockito
- mockito-all
- test
-
-
- org.scalacheck
- scalacheck_${scala.binary.version}
- test
-
-
- org.easymock
- easymock
- test
-
-
- com.novocode
- junit-interface
- test
-
-
-
- target/scala-${scala.binary.version}/classes
- target/scala-${scala.binary.version}/test-classes
-
-
- org.apache.maven.plugins
- maven-antrun-plugin
-
-
- test
-
- run
-
-
- true
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- org.scalatest
- scalatest-maven-plugin
-
-
- ${basedir}/..
- 1
- ${spark.classpath}
-
-
-
-
-
+ org.apache.spark
+ spark-core_2.10
+ jar
+ Spark Project Core
+ http://spark.apache.org/
+
+
+
+ yarn-alpha
+
+
+ org.apache.avro
+ avro
+
+
+
+
+
+
+
+ org.apache.hadoop
+ hadoop-client
+
+
+ net.java.dev.jets3t
+ jets3t
+
+
+ commons-logging
+ commons-logging
+
+
+
+
+ org.apache.curator
+ curator-recipes
+
+
+ org.eclipse.jetty
+ jetty-plus
+
+
+ org.eclipse.jetty
+ jetty-security
+
+
+ org.eclipse.jetty
+ jetty-util
+
+
+ org.eclipse.jetty
+ jetty-server
+
+
+ com.google.guava
+ guava
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ jul-to-slf4j
+
+
+ org.slf4j
+ jcl-over-slf4j
+
+
+ log4j
+ log4j
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ com.ning
+ compress-lzf
+
+
+ org.xerial.snappy
+ snappy-java
+
+
+ com.twitter
+ chill_${scala.binary.version}
+
+
+ com.twitter
+ chill-java
+
+
+ commons-net
+ commons-net
+
+
+ ${akka.group}
+ akka-remote_${scala.binary.version}
+
+
+ ${akka.group}
+ akka-slf4j_${scala.binary.version}
+
+
+ ${akka.group}
+ akka-testkit_${scala.binary.version}
+ test
+
+
+ org.scala-lang
+ scala-library
+
+
+ org.json4s
+ json4s-jackson_${scala.binary.version}
+ 3.2.6
+
+
+
+ org.scala-lang
+ scalap
+
+
+
+
+ colt
+ colt
+
+
+ org.apache.mesos
+ mesos
+
+
+ io.netty
+ netty-all
+
+
+ com.clearspring.analytics
+ stream
+
+
+ com.codahale.metrics
+ metrics-core
+
+
+ com.codahale.metrics
+ metrics-jvm
+
+
+ com.codahale.metrics
+ metrics-json
+
+
+ com.codahale.metrics
+ metrics-graphite
+
+
+ org.apache.derby
+ derby
+ test
+
+
+ org.tachyonproject
+ tachyon
+ 0.4.1-thrift
+
+
+ org.apache.hadoop
+ hadoop-client
+
+
+ org.apache.curator
+ curator-recipes
+
+
+ org.eclipse.jetty
+ jetty-jsp
+
+
+ org.eclipse.jetty
+ jetty-webapp
+
+
+ org.eclipse.jetty
+ jetty-server
+
+
+ org.eclipse.jetty
+ jetty-servlet
+
+
+ junit
+ junit
+
+
+ org.powermock
+ powermock-module-junit4
+
+
+ org.powermock
+ powermock-api-mockito
+
+
+ org.apache.curator
+ curator-test
+
+
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ test
+
+
+ org.mockito
+ mockito-all
+ test
+
+
+ org.scalacheck
+ scalacheck_${scala.binary.version}
+ test
+
+
+ org.easymock
+ easymock
+ test
+
+
+ com.novocode
+ junit-interface
+ test
+
+
+ org.spark-project
+ pyrolite
+ 2.0.1
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
+ org.apache.maven.plugins
+ maven-antrun-plugin
+
+
+ test
+
+ run
+
+
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ ${basedir}/..
+ 1
+ ${spark.classpath}
+
+
+
+
+
diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
index 9f13b39909481..840a1bd93bfbb 100644
--- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
+++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
@@ -23,17 +23,18 @@
* Expose some commonly useful storage level constants.
*/
public class StorageLevels {
- public static final StorageLevel NONE = create(false, false, false, 1);
- public static final StorageLevel DISK_ONLY = create(true, false, false, 1);
- public static final StorageLevel DISK_ONLY_2 = create(true, false, false, 2);
- public static final StorageLevel MEMORY_ONLY = create(false, true, true, 1);
- public static final StorageLevel MEMORY_ONLY_2 = create(false, true, true, 2);
- public static final StorageLevel MEMORY_ONLY_SER = create(false, true, false, 1);
- public static final StorageLevel MEMORY_ONLY_SER_2 = create(false, true, false, 2);
- public static final StorageLevel MEMORY_AND_DISK = create(true, true, true, 1);
- public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, true, 2);
- public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, 1);
- public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, 2);
+ public static final StorageLevel NONE = create(false, false, false, false, 1);
+ public static final StorageLevel DISK_ONLY = create(true, false, false, false, 1);
+ public static final StorageLevel DISK_ONLY_2 = create(true, false, false, false, 2);
+ public static final StorageLevel MEMORY_ONLY = create(false, true, false, true, 1);
+ public static final StorageLevel MEMORY_ONLY_2 = create(false, true, false, true, 2);
+ public static final StorageLevel MEMORY_ONLY_SER = create(false, true, false, false, 1);
+ public static final StorageLevel MEMORY_ONLY_SER_2 = create(false, true, false, false, 2);
+ public static final StorageLevel MEMORY_AND_DISK = create(true, true, false, true, 1);
+ public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, false, true, 2);
+ public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, false, 1);
+ public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, false, 2);
+ public static final StorageLevel OFF_HEAP = create(false, false, true, false, 1);
/**
* Create a new StorageLevel object.
@@ -42,7 +43,26 @@ public class StorageLevels {
* @param deserialized saved as deserialized objects, if true
* @param replication replication factor
*/
- public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
- return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
+ @Deprecated
+ public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized,
+ int replication) {
+ return StorageLevel.apply(useDisk, useMemory, false, deserialized, replication);
+ }
+
+ /**
+ * Create a new StorageLevel object.
+ * @param useDisk saved to disk, if true
+ * @param useMemory saved to memory, if true
+ * @param useOffHeap saved to Tachyon, if true
+ * @param deserialized saved as deserialized objects, if true
+ * @param replication replication factor
+ */
+ public static StorageLevel create(
+ boolean useDisk,
+ boolean useMemory,
+ boolean useOffHeap,
+ boolean deserialized,
+ int replication) {
+ return StorageLevel.apply(useDisk, useMemory, useOffHeap, deserialized, replication);
}
}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
new file mode 100644
index 0000000000000..57fd0a7a80494
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A function that returns zero or more records of type Double from each input record.
+ */
+public interface DoubleFlatMapFunction<T> extends Serializable {
+ public Iterable<Double> call(T t) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
new file mode 100644
index 0000000000000..150144e0e418c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A function that returns Doubles, and can be used to construct DoubleRDDs.
+ */
+public interface DoubleFunction<T> extends Serializable {
+ public double call(T t) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
new file mode 100644
index 0000000000000..23f5fdd43631b
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A function that returns zero or more output records from each input record.
+ */
+public interface FlatMapFunction<T, R> extends Serializable {
+ public Iterable<R> call(T t) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
new file mode 100644
index 0000000000000..c48e92f535ff5
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A function that takes two inputs and returns zero or more output records.
+ */
+public interface FlatMapFunction2<T1, T2, R> extends Serializable {
+ public Iterable<R> call(T1 t1, T2 t2) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java
new file mode 100644
index 0000000000000..d00551bb0add6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for functions whose return types do not create special RDDs. PairFunction and
+ * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed
+ * when mapping RDDs of other types.
+ */
+public interface Function<T1, R> extends Serializable {
+ public R call(T1 v1) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function2.java b/core/src/main/java/org/apache/spark/api/java/function/Function2.java
new file mode 100644
index 0000000000000..793caaa61ac5a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function2.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A two-argument function that takes arguments of type T1 and T2 and returns an R.
+ */
+public interface Function2<T1, T2, R> extends Serializable {
+ public R call(T1 v1, T2 v2) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function3.java b/core/src/main/java/org/apache/spark/api/java/function/Function3.java
new file mode 100644
index 0000000000000..b4151c3417df4
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function3.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
+ */
+public interface Function3<T1, T2, T3, R> extends Serializable {
+ public R call(T1 v1, T2 v2, T3 v3) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java
new file mode 100644
index 0000000000000..691ef2eceb1f6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+import scala.Tuple2;
+
+/**
+ * A function that returns zero or more key-value pair records from each input record. The
+ * key-value pairs are represented as scala.Tuple2 objects.
+ */
+public interface PairFlatMapFunction<T, K, V> extends Serializable {
+ public Iterable<Tuple2<K, V>> call(T t) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
new file mode 100644
index 0000000000000..abd9bcc07ac61
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+import scala.Tuple2;
+
+/**
+ * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs.
+ */
+public interface PairFunction<T, K, V> extends Serializable {
+ public Tuple2<K, V> call(T t) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java
new file mode 100644
index 0000000000000..2a10435b7523a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A function with no return value.
+ */
+public interface VoidFunction<T> extends Serializable {
+ public void call(T t) throws Exception;
+}
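These interfaces are plain Java SAM types, so implementing one is just a matter of overriding call. A hedged sketch from Scala (names and values are illustrative only):

    import org.apache.spark.api.java.function.{Function => JFunction, PairFunction}

    // A Function[T1, R] mapping a line to its length.
    val lineLength = new JFunction[String, Int] {
      override def call(s: String): Int = s.length
    }

    // A PairFunction[T, K, V] pairing each word with a count of 1,
    // suitable for building a pair RDD.
    val wordToPair = new PairFunction[String, String, Int] {
      override def call(w: String): (String, Int) = (w, 1)
    }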
diff --git a/core/src/main/java/org/apache/spark/package-info.java b/core/src/main/java/org/apache/spark/package-info.java
new file mode 100644
index 0000000000000..4426c7afcebdd
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/package-info.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Core Spark classes in Scala. A few classes here, such as {@link org.apache.spark.Accumulator}
+ * and {@link org.apache.spark.storage.StorageLevel}, are also used in Java, but the
+ * {@link org.apache.spark.api.java} package contains the main Java API.
+ */
+package org.apache.spark;
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index fe54c34ffb1da..599c3ac9b57c0 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -78,3 +78,12 @@ table.sortable thead {
background-repeat: repeat-x;
filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0);
}
+
+span.kill-link {
+ margin-right: 2px;
+ color: gray;
+}
+
+span.kill-link a {
+ color: gray;
+}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index d5f3e3f6ec496..6d652faae149a 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -104,8 +104,11 @@ class Accumulable[R, T] (
* Set the accumulator's value; only allowed on master.
*/
def value_= (newValue: R) {
- if (!deserialized) value_ = newValue
- else throw new UnsupportedOperationException("Can't assign accumulator value in task")
+ if (!deserialized) {
+ value_ = newValue
+ } else {
+ throw new UnsupportedOperationException("Can't assign accumulator value in task")
+ }
}
/**
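The restructured guard keeps the same contract: tasks may only add to an accumulator; assignment is driver-only. An illustrative sketch, assuming a live SparkContext sc:

    val acc = sc.accumulator(0)
    sc.parallelize(1 to 10).foreach { i =>
      acc += i          // adding from within a task is allowed
      // acc.value = 5  // would throw UnsupportedOperationException in a task
    }
    println(acc.value)  // reading back on the driver is fine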
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index c4579cf6ad560..59fdf659c9e11 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,17 +17,18 @@
package org.apache.spark
-import scala.{Option, deprecated}
-
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
/**
+ * :: DeveloperApi ::
* A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
*/
+@DeveloperApi
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 754b46a4c7df2..a67392441ed29 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -79,7 +79,6 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
- shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 1daabecf23292..811610c657b62 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -20,11 +20,12 @@ package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockManager, RDDBlockId, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockStatus, RDDBlockId, StorageLevel}
-/** Spark class responsible for passing RDDs split contents to the BlockManager and making
- sure a node doesn't load two copies of an RDD at once.
- */
+/**
+ * Spark class responsible for passing RDDs split contents to the BlockManager and making
+ * sure a node doesn't load two copies of an RDD at once.
+ */
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
/** Keys of RDD splits that are being computed/loaded. */
@@ -46,14 +47,19 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (loading.contains(key)) {
logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
while (loading.contains(key)) {
- try {loading.wait()} catch {case _ : Throwable =>}
+ try {
+ loading.wait()
+ } catch {
+ case e: Exception =>
+ logWarning(s"Got an exception while waiting for another thread to load $key", e)
+ }
}
logInfo("Finished waiting for %s".format(key))
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
+ /* See whether someone else has successfully loaded it. The main way this would fail
+ * is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ * partition but we didn't want to make space for it. However, that case is unlikely
+ * because it's unlikely that two threads would work on the same RDD partition. One
+ * downside of the current code is that threads wait serially if this does happen. */
blockManager.get(key) match {
case Some(values) =>
return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
@@ -69,12 +75,47 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// If we got here, we have to load the split
logInfo("Partition %s not found, computing it".format(key))
val computedValues = rdd.computeOrReadCheckpoint(split, context)
+
// Persist the result, so long as the task is not running locally
- if (context.runningLocally) { return computedValues }
- val elements = new ArrayBuffer[Any]
- elements ++= computedValues
- blockManager.put(key, elements, storageLevel, tellMaster = true)
- elements.iterator.asInstanceOf[Iterator[T]]
+ if (context.runningLocally) {
+ return computedValues
+ }
+
+ // Keep track of blocks with updated statuses
+ var updatedBlocks = Seq[(BlockId, BlockStatus)]()
+ val returnValue: Iterator[T] = {
+ if (storageLevel.useDisk && !storageLevel.useMemory) {
+ /* In the case that this RDD is to be persisted using DISK_ONLY
+ * the iterator will be passed directly to the blockManager (rather then
+ * caching it to an ArrayBuffer first), then the resulting block data iterator
+ * will be passed back to the user. If the iterator generates a lot of data,
+ * this means that it doesn't all have to be held in memory at one time.
+ * This could also apply to MEMORY_ONLY_SER storage, but we need to make sure
+ * blocks aren't dropped by the block store before enabling that. */
+ updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
+ blockManager.get(key) match {
+ case Some(values) =>
+ values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Failure to store %s".format(key))
+ throw new Exception("Block manager failed to return persisted values")
+ }
+ } else {
+ // In this case the RDD is cached to an array buffer. This will save the results
+ // if we're dealing with a 'one-time' iterator
+ val elements = new ArrayBuffer[Any]
+ elements ++= computedValues
+ updatedBlocks = blockManager.put(key, elements, storageLevel, tellMaster = true)
+ elements.iterator.asInstanceOf[Iterator[T]]
+ }
+ }
+
+ // Update task metrics to include any blocks whose storage status is updated
+ val metrics = context.taskMetrics
+ metrics.updatedBlocks = Some(updatedBlocks)
+
+ new InterruptibleIterator(context, returnValue)
+
} finally {
loading.synchronized {
loading.remove(key)
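The practical effect of the new DISK_ONLY branch: a computed partition is streamed straight to the BlockManager rather than buffered in an ArrayBuffer first. A hedged usage sketch, assuming a SparkContext sc:

    import org.apache.spark.storage.StorageLevel

    // With DISK_ONLY the iterator goes directly to the block store, so the
    // whole partition never has to be held in memory at once.
    val big = sc.parallelize(1 to 10000000).persist(StorageLevel.DISK_ONLY)
    big.count()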
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
new file mode 100644
index 0000000000000..54e08d7866f75
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.lang.ref.{ReferenceQueue, WeakReference}
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+
+/**
+ * Classes that represent cleaning tasks.
+ */
+private sealed trait CleanupTask
+private case class CleanRDD(rddId: Int) extends CleanupTask
+private case class CleanShuffle(shuffleId: Int) extends CleanupTask
+private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
+
+/**
+ * A WeakReference associated with a CleanupTask.
+ *
+ * When the referent object becomes only weakly reachable, the corresponding
+ * CleanupTaskWeakReference is automatically added to the given reference queue.
+ */
+private class CleanupTaskWeakReference(
+ val task: CleanupTask,
+ referent: AnyRef,
+ referenceQueue: ReferenceQueue[AnyRef])
+ extends WeakReference(referent, referenceQueue)
+
+/**
+ * An asynchronous cleaner for RDD, shuffle, and broadcast state.
+ *
+ * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
+ * to be processed when the associated object goes out of scope of the application. Actual
+ * cleanup is performed in a separate daemon thread.
+ */
+private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
+
+ private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
+ with SynchronizedBuffer[CleanupTaskWeakReference]
+
+ private val referenceQueue = new ReferenceQueue[AnyRef]
+
+ private val listeners = new ArrayBuffer[CleanerListener]
+ with SynchronizedBuffer[CleanerListener]
+
+ private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
+
+ /**
+ * Whether the cleaning thread will block on cleanup tasks.
+ * This is set to true only for tests.
+ */
+ private val blockOnCleanupTasks = sc.conf.getBoolean(
+ "spark.cleaner.referenceTracking.blocking", false)
+
+ @volatile private var stopped = false
+
+ /** Attach a listener object to get information of when objects are cleaned. */
+ def attachListener(listener: CleanerListener) {
+ listeners += listener
+ }
+
+ /** Start the cleaner. */
+ def start() {
+ cleaningThread.setDaemon(true)
+ cleaningThread.setName("Spark Context Cleaner")
+ cleaningThread.start()
+ }
+
+ /** Stop the cleaner. */
+ def stop() {
+ stopped = true
+ }
+
+ /** Register a RDD for cleanup when it is garbage collected. */
+ def registerRDDForCleanup(rdd: RDD[_]) {
+ registerForCleanup(rdd, CleanRDD(rdd.id))
+ }
+
+ /** Register a ShuffleDependency for cleanup when it is garbage collected. */
+ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
+ registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
+ }
+
+ /** Register a Broadcast for cleanup when it is garbage collected. */
+ def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+ registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
+ }
+
+ /** Register an object for cleanup. */
+ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
+ referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
+ }
+
+ /** Keep cleaning RDD, shuffle, and broadcast state. */
+ private def keepCleaning() {
+ while (!stopped) {
+ try {
+ val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
+ .map(_.asInstanceOf[CleanupTaskWeakReference])
+ reference.map(_.task).foreach { task =>
+ logDebug("Got cleaning task " + task)
+ referenceBuffer -= reference.get
+ task match {
+ case CleanRDD(rddId) =>
+ doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
+ case CleanShuffle(shuffleId) =>
+ doCleanupShuffle(shuffleId, blocking = blockOnCleanupTasks)
+ case CleanBroadcast(broadcastId) =>
+ doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
+ }
+ }
+ } catch {
+ case t: Throwable => logError("Error in cleaning thread", t)
+ }
+ }
+ }
+
+ /** Perform RDD cleanup. */
+ def doCleanupRDD(rddId: Int, blocking: Boolean) {
+ try {
+ logDebug("Cleaning RDD " + rddId)
+ sc.unpersistRDD(rddId, blocking)
+ listeners.foreach(_.rddCleaned(rddId))
+ logInfo("Cleaned RDD " + rddId)
+ } catch {
+ case t: Throwable => logError("Error cleaning RDD " + rddId, t)
+ }
+ }
+
+ /** Perform shuffle cleanup. */
+ def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
+ try {
+ logDebug("Cleaning shuffle " + shuffleId)
+ mapOutputTrackerMaster.unregisterShuffle(shuffleId)
+ blockManagerMaster.removeShuffle(shuffleId, blocking)
+ listeners.foreach(_.shuffleCleaned(shuffleId))
+ logInfo("Cleaned shuffle " + shuffleId)
+ } catch {
+ case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t)
+ }
+ }
+
+ /** Perform broadcast cleanup. */
+ def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
+ try {
+ logDebug("Cleaning broadcast " + broadcastId)
+ broadcastManager.unbroadcast(broadcastId, true, blocking)
+ listeners.foreach(_.broadcastCleaned(broadcastId))
+ logInfo("Cleaned broadcast " + broadcastId)
+ } catch {
+ case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
+ }
+ }
+
+ private def blockManagerMaster = sc.env.blockManager.master
+ private def broadcastManager = sc.env.broadcastManager
+ private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+
+ // Used for testing. These methods explicitly block until cleanup is completed
+ // to ensure more reliable testing.
+}
+
+private object ContextCleaner {
+ private val REF_QUEUE_POLL_TIMEOUT = 100
+}
+
+/**
+ * Listener class used for testing when any item has been cleaned by the Cleaner class.
+ */
+private[spark] trait CleanerListener {
+ def rddCleaned(rddId: Int)
+ def shuffleCleaned(shuffleId: Int)
+ def broadcastCleaned(broadcastId: Long)
+}
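For tests, a CleanerListener can observe cleanup events. A minimal sketch (the class name is hypothetical, and since CleanerListener is private[spark] this would live inside the org.apache.spark package):

    // Hypothetical test-only listener that counts cleanup events.
    class CountingListener extends CleanerListener {
      var rdds, shuffles, broadcasts = 0
      def rddCleaned(rddId: Int) { rdds += 1 }
      def shuffleCleaned(shuffleId: Int) { shuffles += 1 }
      def broadcastCleaned(broadcastId: Long) { broadcasts += 1 }
    }

    // sc.cleaner.foreach(_.attachListener(new CountingListener))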
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index cc30105940d1a..2c31cc20211ff 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -17,18 +17,24 @@
package org.apache.spark
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
/**
+ * :: DeveloperApi ::
* Base class for dependencies.
*/
+@DeveloperApi
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
/**
+ * :: DeveloperApi ::
* Base class for dependencies where each partition of the parent RDD is used by at most one
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
+@DeveloperApi
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
* Get the parent partitions for a child partition.
@@ -40,36 +46,46 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
+ * :: DeveloperApi ::
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
- * @param serializerClass class name of the serializer to use
+ * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
+ * the default serializer, as specified by `spark.serializer` config option, will
+ * be used.
*/
+@DeveloperApi
class ShuffleDependency[K, V](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
- val serializerClass: String = null)
+ val serializer: Serializer = null)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
+
+ rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
/**
+ * :: DeveloperApi ::
* Represents a one-to-one dependency between partitions of the parent and child RDDs.
*/
+@DeveloperApi
class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
/**
+ * :: DeveloperApi ::
* Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
* @param rdd the parent RDD
* @param inStart the start of the range in the parent RDD
* @param outStart the start of the range in the child RDD
* @param length the length of the range
*/
+@DeveloperApi
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) {
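Since ShuffleDependency now takes a Serializer instance instead of a class-name string (with null still falling back to spark.serializer), a caller can pass one directly. Sketch under assumed values:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.serializer.KryoSerializer

    // pairRdd: RDD[(Int, String)] and conf: SparkConf are assumed to exist.
    val dep = new ShuffleDependency[Int, String](
      pairRdd, new HashPartitioner(8), new KryoSerializer(conf))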
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index f2decd14ef6d9..1e4dec86a0530 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -21,13 +21,16 @@ import scala.concurrent._
import scala.concurrent.duration.Duration
import scala.util.Try
+import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
/**
+ * :: Experimental ::
* A future for the result of an action to support cancellation. This is an extension of the
* Scala Future interface to support cancellation.
*/
+@Experimental
trait FutureAction[T] extends Future[T] {
// Note that we redefine methods of the Future trait here explicitly so we can specify a different
// documentation (with reference to the word "action").
@@ -84,9 +87,11 @@ trait FutureAction[T] extends Future[T] {
/**
+ * :: Experimental ::
* A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
+@Experimental
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {
@@ -141,17 +146,19 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
private def awaitResult(): Try[T] = {
jobWaiter.awaitResult() match {
case JobSucceeded => scala.util.Success(resultFunc)
- case JobFailed(e: Exception, _) => scala.util.Failure(e)
+ case JobFailed(e: Exception) => scala.util.Failure(e)
}
}
}
/**
+ * :: Experimental ::
* A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
* takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
* action thread if it is being blocked by a job.
*/
+@Experimental
class ComplexFutureAction[T] extends FutureAction[T] {
// Pointer to the thread that is executing the action. It is set when the action is run.
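A short usage sketch of the now-Experimental API, assuming a SparkContext sc and the async-action implicits from SparkContext._:

    import scala.concurrent.Await
    import scala.concurrent.duration._
    import org.apache.spark.SparkContext._

    val future = sc.parallelize(1 to 1000000).countAsync()
    // future.cancel()  // cancellation is the point of FutureAction
    val n = Await.result(future, 30.seconds)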
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index d3264a4bb3c81..a6e300d345786 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -23,14 +23,14 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
-private[spark] class HttpFileServer extends Logging {
-
+private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging {
+
var baseDir : File = null
var fileDir : File = null
var jarDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
-
+
def initialize() {
baseDir = Utils.createTempDir()
fileDir = new File(baseDir, "files")
@@ -38,28 +38,29 @@ private[spark] class HttpFileServer extends Logging {
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir)
+ httpServer = new HttpServer(baseDir, securityManager)
httpServer.start()
serverUri = httpServer.uri
+ logDebug("HTTP file server started at: " + serverUri)
}
-
+
def stop() {
httpServer.stop()
}
-
+
def addFile(file: File) : String = {
addFileToDir(file, fileDir)
serverUri + "/files/" + file.getName
}
-
+
def addJar(file: File) : String = {
addFileToDir(file, jarDir)
serverUri + "/jars/" + file.getName
}
-
+
def addFileToDir(file: File, dir: File) : String = {
Files.copy(file, new File(dir, file.getName))
dir + "/" + file.getName
}
-
+
}
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 759e68ee0cc61..7e9b517f901a2 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,15 +19,18 @@ package org.apache.spark
import java.io.File
+import org.eclipse.jetty.util.security.{Constraint, Password}
+import org.eclipse.jetty.security.authentication.DigestAuthenticator
+import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler}
+
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
-import org.eclipse.jetty.server.handler.DefaultHandler
-import org.eclipse.jetty.server.handler.HandlerList
-import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.apache.spark.util.Utils
+
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
@@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
-private[spark] class HttpServer(resourceBase: File) extends Logging {
+private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager)
+ extends Logging {
private var server: Server = null
private var port: Int = -1
@@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
server.setThreadPool(threadPool)
val resHandler = new ResourceHandler
resHandler.setResourceBase(resourceBase.getAbsolutePath)
+
val handlerList = new HandlerList
handlerList.setHandlers(Array(resHandler, new DefaultHandler))
- server.setHandler(handlerList)
+
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("HttpServer is using security")
+ val sh = setupSecurityHandler(securityManager)
+ // make sure we go through security handler to get resources
+ sh.setHandler(handlerList)
+ server.setHandler(sh)
+ } else {
+ logDebug("HttpServer is not using security")
+ server.setHandler(handlerList)
+ }
+
server.start()
port = server.getConnectors()(0).getLocalPort()
}
}
+ /**
+ * Set up Jetty to use the HashLoginService with a single user and our
+ * shared secret. Configure it to use DIGEST-MD5 authentication so that the password
+ * isn't passed in plaintext.
+ */
+ private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = {
+ val constraint = new Constraint()
+ // use DIGEST-MD5 as the authentication mechanism
+ constraint.setName(Constraint.__DIGEST_AUTH)
+ constraint.setRoles(Array("user"))
+ constraint.setAuthenticate(true)
+ constraint.setDataConstraint(Constraint.DC_NONE)
+
+ val cm = new ConstraintMapping()
+ cm.setConstraint(constraint)
+ cm.setPathSpec("/*")
+ val sh = new ConstraintSecurityHandler()
+
+ // the hashLoginService lets us do a single user and
+ // secret right now. This could be changed to use the
+ // JAASLoginService for other options.
+ val hashLogin = new HashLoginService()
+
+ val secretKey = securityMgr.getSecretKey()
+ if (secretKey == null) {
+ throw new Exception("Error: secret key is null with authentication on")
+ }
+ hashLogin.putUser(securityMgr.getHttpUser(), new Password(secretKey), Array("user"))
+ sh.setLoginService(hashLogin)
+ sh.setAuthenticator(new DigestAuthenticator());
+ sh.setConstraintMappings(Array(cm))
+ sh
+ }
+
def stop() {
if (server == null) {
throw new ServerStateException("Server is already stopped")
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index 9b1601d5b95fa..ec11dbbffaaf8 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -21,10 +21,20 @@ package org.apache.spark
* An iterator that wraps around an existing iterator to provide task killing functionality.
* It works by checking the interrupted flag in [[TaskContext]].
*/
-class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {
- def hasNext: Boolean = !context.interrupted && delegate.hasNext
+ def hasNext: Boolean = {
+ // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
+ // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
+ // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
+ // introduces an expensive read fence.
+ if (context.interrupted) {
+ throw new TaskKilledException
+ } else {
+ delegate.hasNext
+ }
+ }
def next(): T = delegate.next()
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index b749e5414dab6..50d8e93e1f0d7 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -19,12 +19,21 @@ package org.apache.spark
import org.apache.log4j.{LogManager, PropertyConfigurator}
import org.slf4j.{Logger, LoggerFactory}
+import org.slf4j.impl.StaticLoggerBinder
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
+ * :: DeveloperApi ::
* Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows
* logging messages at different levels using methods that only evaluate parameters lazily if the
* log level is enabled.
+ *
+ * NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility.
+ * This will likely be changed or removed in future releases.
*/
+@DeveloperApi
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
@@ -52,7 +61,7 @@ trait Logging {
protected def logDebug(msg: => String) {
if (log.isDebugEnabled) log.debug(msg)
}
-
+
protected def logTrace(msg: => String) {
if (log.isTraceEnabled) log.trace(msg)
}
@@ -101,16 +110,17 @@ trait Logging {
}
private def initializeLogging() {
- // If Log4j doesn't seem initialized, load a default properties file
+ // If Log4j is being used, but is not initialized, load a default properties file
+ val binder = StaticLoggerBinder.getSingleton
+ val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory")
val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4jInitialized) {
+ if (!log4jInitialized && usingLog4j) {
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- val classLoader = this.getClass.getClassLoader
- Option(classLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
PropertyConfigurator.configure(url)
log.info(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
+ case None =>
System.err.println(s"Spark was unable to load $defaultLogProps")
}
}
@@ -125,4 +135,16 @@ trait Logging {
private object Logging {
@volatile private var initialized = false
val initLock = new Object()
+ try {
+ // We use reflection here to handle the case where users remove the
+ // jul-to-slf4j bridge in order to route their logs to JUL.
+ val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler")
+ bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
+ val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean]
+ if (!installed) {
+ bridgeClass.getMethod("install").invoke(null)
+ }
+ } catch {
+ case e: ClassNotFoundException => // can't log anything yet so just fail silently
+ }
}
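When the jul-to-slf4j artifact is on the classpath, the reflective block above is equivalent to the direct calls below; reflection is used only so Spark still loads when the bridge has been excluded:

    import org.slf4j.bridge.SLF4JBridgeHandler

    SLF4JBridgeHandler.removeHandlersForRootLogger()
    if (!SLF4JBridgeHandler.isInstalled()) {
      SLF4JBridgeHandler.install()
    }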
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 5968973132942..ee82d9fa7874b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,28 +20,43 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{HashSet, HashMap, Map}
import scala.concurrent.Await
import akka.actor._
import akka.pattern.ask
-
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.util._
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+/** Actor class for MapOutputTrackerMaster */
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
extends Actor with Logging {
+ val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+
def receive = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = sender.path.address.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
- sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
+ val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
+ val serializedSize = mapOutputStatuses.size
+ if (serializedSize > maxAkkaFrameSize) {
+ val msg = s"Map output statuses were $serializedSize bytes which " +
+ s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
+
+ /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
+ * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
+ * will ultimately remove this entire code path. */
+ val exception = new SparkException(msg)
+ logError(msg, exception)
+ throw exception
+ }
+ sender ! mapOutputStatuses
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -50,26 +65,41 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
}
}
-private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
-
+/**
+ * Class that keeps track of the location of the map output of
+ * a stage. This is abstract because different versions of MapOutputTracker
+ * (driver and worker) use different HashMaps to store their metadata.
+ */
+private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
private val timeout = AkkaUtils.askTimeout(conf)
- // Set to the MapOutputTrackerActor living on the driver
+ /** Set to the MapOutputTrackerActor living on the driver. */
var trackerActor: ActorRef = _
- protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ /**
+ * This HashMap has different behavior for the master and the workers.
+ *
+ * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks.
+ * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the
+ * master's corresponding HashMap.
+ */
+ protected val mapStatuses: Map[Int, Array[MapStatus]]
- // Incremented every time a fetch fails so that client nodes know to clear
- // their cache of map output locations if this happens.
+ /**
+ * Incremented every time a fetch fails so that client nodes know to clear
+ * their cache of map output locations if this happens.
+ */
protected var epoch: Long = 0
- protected val epochLock = new java.lang.Object
+ protected val epochLock = new AnyRef
- private val metadataCleaner =
- new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
+ /** Remembers which map output locations are currently being fetched on a worker. */
+ private val fetching = new HashSet[Int]
- // Send a message to the trackerActor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- private def askTracker(message: Any): Any = {
+ /**
+ * Send a message to the trackerActor and get its result within a default timeout, or
+ * throw a SparkException if this fails.
+ */
+ protected def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
Await.result(future, timeout)
@@ -79,17 +109,17 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
- private def communicate(message: Any) {
+ /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
+ protected def sendTracker(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- // Remembers which map output locations are currently being fetched on a worker
- private val fetching = new HashSet[Int]
-
- // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
+ /**
+ * Called from executors to get the server URIs and output sizes of the map outputs of
+ * a given shuffle.
+ */
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
@@ -137,8 +167,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
- }
- else {
+ } else {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
@@ -149,27 +178,18 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
- protected def cleanup(cleanupTime: Long) {
- mapStatuses.clearOldValues(cleanupTime)
- }
-
- def stop() {
- communicate(StopMapOutputTracker)
- mapStatuses.clear()
- metadataCleaner.cancel()
- trackerActor = null
- }
-
- // Called to get current epoch number
+ /** Called to get current epoch number. */
def getEpoch: Long = {
epochLock.synchronized {
return epoch
}
}
- // Called on workers to update the epoch number, potentially clearing old outputs
- // because of a fetch failure. (Each worker task calls this with the latest epoch
- // number on the master at the time it was created.)
+ /**
+ * Called from executors to update the epoch number, potentially clearing old outputs
+ * because of a fetch failure. Each worker task calls this with the latest epoch
+ * number on the master at the time it was created.
+ */
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
if (newEpoch > epoch) {
@@ -179,17 +199,40 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
}
+
+ /** Unregister shuffle data. */
+ def unregisterShuffle(shuffleId: Int) {
+ mapStatuses.remove(shuffleId)
+ }
+
+ /** Stop the tracker. */
+ def stop() { }
}
+/**
+ * MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map
+ * output information, which allows old output information to be cleaned up based on a TTL.
+ */
private[spark] class MapOutputTrackerMaster(conf: SparkConf)
extends MapOutputTracker(conf) {
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
private var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ /**
+ * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
+ * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
+ * Other than these two scenarios, nothing should be dropped from this HashMap.
+ */
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
+
+ // For cleaning up TimeStampedHashMaps
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
@@ -201,6 +244,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ /** Register multiple map output information for the given shuffle */
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
@@ -208,6 +252,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ /** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
val arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
@@ -223,6 +268,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ /** Unregister shuffle data */
+ override def unregisterShuffle(shuffleId: Int) {
+ mapStatuses.remove(shuffleId)
+ cachedSerializedStatuses.remove(shuffleId)
+ }
+
+ /** Check if the given shuffle is being tracked */
+ def containsShuffle(shuffleId: Int): Boolean = {
+ cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
+ }
+
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
@@ -259,23 +315,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
bytes
}
- protected override def cleanup(cleanupTime: Long) {
- super.cleanup(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
- }
-
override def stop() {
- super.stop()
+ sendTracker(StopMapOutputTracker)
+ mapStatuses.clear()
+ trackerActor = null
+ metadataCleaner.cancel()
cachedSerializedStatuses.clear()
}
- override def updateEpoch(newEpoch: Long) {
- // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ private def cleanup(cleanupTime: Long) {
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
+}
- def has(shuffleId: Int): Boolean = {
- cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
- }
+/**
+ * MapOutputTracker for the workers, which fetches map output information from the driver's
+ * MapOutputTrackerMaster.
+ */
+private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
+ protected val mapStatuses = new HashMap[Int, Array[MapStatus]]
}
private[spark] object MapOutputTracker {
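A hedged sketch of the driver-side lifecycle these methods support (ids, counts, and conf are illustrative):

    val tracker = new MapOutputTrackerMaster(conf)  // conf: SparkConf, assumed
    tracker.registerShuffle(shuffleId = 0, numMaps = 4)
    tracker.containsShuffle(0)    // true once registered
    // ... ShuffleMapTasks then call registerMapOutput as they complete ...
    tracker.unregisterShuffle(0)  // also drops the cached serialized statuses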
diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala
index 87914a061f5d7..27892dbd2a0bc 100644
--- a/core/src/main/scala/org/apache/spark/Partition.scala
+++ b/core/src/main/scala/org/apache/spark/Partition.scala
@@ -25,7 +25,7 @@ trait Partition extends Serializable {
* Get the split's index within its parent RDD
*/
def index: Int
-
+
// A better default implementation of HashCode
override def hashCode(): Int = index
}
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index ad9988226470c..9155159cf6aeb 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -89,12 +89,14 @@ class HashPartitioner(partitions: Int) extends Partitioner {
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
* equal ranges. The ranges are determined by sampling the content of the RDD passed in.
*/
-class RangePartitioner[K <% Ordered[K]: ClassTag, V](
+class RangePartitioner[K : Ordering : ClassTag, V](
partitions: Int,
@transient rdd: RDD[_ <: Product2[K,V]],
private val ascending: Boolean = true)
extends Partitioner {
+ private val ordering = implicitly[Ordering[K]]
+
// An array of upper bounds for the first (partitions - 1) partitions
private val rangeBounds: Array[K] = {
if (partitions == 1) {
@@ -103,7 +105,7 @@ class RangePartitioner[K <% Ordered[K]: ClassTag, V](
val rddSize = rdd.count()
val maxSampleSize = partitions * 20.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
- val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
+ val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
if (rddSample.length == 0) {
Array()
} else {
@@ -126,7 +128,7 @@ class RangePartitioner[K <% Ordered[K]: ClassTag, V](
var partition = 0
if (rangeBounds.length < 1000) {
// If we have fewer than 1000 partitions, use naive search
- while (partition < rangeBounds.length && k > rangeBounds(partition)) {
+ while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
partition += 1
}
} else {
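The move from the view bound `K <% Ordered[K]` to the context bound `K : Ordering` means the partitioner now resolves a single implicit `Ordering[K]` instead of wrapping every key in an `Ordered[K]`. A minimal sketch of what a caller supplies under the new signature (the `Version` key type is hypothetical, not part of this patch):

```scala
// Hypothetical key type, used only to illustrate the new context bound.
case class Version(major: Int, minor: Int)

object Version {
  // RangePartitioner[K : Ordering : ClassTag, V] picks this up implicitly;
  // comparisons then go through ordering.gt(k, bound) as in the hunk above.
  implicit val ordering: Ordering[Version] =
    Ordering.by(v => (v.major, v.minor))
}
```

With that in scope, `new RangePartitioner[Version, String](8, rdd)` compiles for any `rdd: RDD[(Version, String)]`, and the `collect().sorted` call in the sampling step uses the same ordering.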
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
new file mode 100644
index 0000000000000..b4b0067801259
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.{Authenticator, PasswordAuthentication}
+
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.deploy.SparkHadoopUtil
+
+/**
+ * Spark class responsible for security.
+ *
+ * In general this class should be instantiated by the SparkEnv and most components
+ * should access it from that. There are some cases where the SparkEnv hasn't been
+ * initialized yet and this class must be instantiated directly.
+ *
+ * Spark currently supports authentication via a shared secret.
+ * Authentication can be configured to be on via the 'spark.authenticate' configuration
+ * parameter. This parameter controls whether the Spark communication protocols do
+ * authentication using the shared secret. This authentication is a basic handshake to
+ * make sure both sides have the same shared secret and are allowed to communicate.
+ * If the shared secret is not identical they will not be allowed to communicate.
+ *
+ * The Spark UI can also be secured by using javax servlet filters. A user may want to
+ * secure the UI if it has data that other users should not be allowed to see. The javax
+ * servlet filter specified by the user can authenticate the user and then once the user
+ * is logged in, Spark can compare that user versus the view acls to make sure they are
+ * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls'
+ * control the behavior of the acls. Note that the person who started the application
+ * always has view access to the UI.
+ *
+ * Spark does not currently support encryption after authentication.
+ *
+ * At this point Spark has multiple communication protocols that need to be secured and
+ * different underlying mechanisms are used depending on the protocol:
+ *
+ * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality.
+ * Akka remoting allows you to specify a secure cookie that will be exchanged
+ * and ensured to be identical in the connection handshake between the client
+ * and the server. If they are not identical then the client will be refused
+ * a connection to the server. There is no control over the underlying
+ * authentication mechanism, so it is not clear whether the password is passed
+ * in plaintext or protected by DIGEST-MD5 or some other mechanism.
+ * Akka also has an option to turn on SSL; this option is not currently
+ * supported, but we could add a configuration option for it in the future.
+ *
+ * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
+ * for the HttpServer. Jetty supports multiple authentication mechanisms -
+ * Basic, Digest, Form, Spengo, etc. It also supports multiple different login
+ * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService
+ * to authenticate using DIGEST-MD5 via a single user and the shared secret.
+ * Since we are using DIGEST-MD5, the shared secret is not passed on the wire
+ * in plaintext.
+ * We currently do not support SSL (https), but Jetty can be configured to use it
+ * so we could add a configuration option for this in the future.
+ *
+ * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5.
+ * Any clients must specify the user and password. A default Authenticator is
+ * installed in the SecurityManager to define how the authentication is done;
+ * in this case it gets the user name and password from the request.
+ *
+ * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * exchange messages. For this we use the Java SASL
+ * (Simple Authentication and Security Layer) API and again use DIGEST-MD5
+ * as the authentication mechanism. This means the shared secret is not passed
+ * over the wire in plaintext.
+ * Note that SASL is pluggable as to what mechanism it uses. We currently use
+ * DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
+ * Spark currently supports "auth" for the quality of protection, which means
+ * the connection does not support integrity or privacy protection (encryption)
+ * after authentication. SASL also supports "auth-int" and "auth-conf", which
+ * Spark could support in the future to allow the user to specify the quality
+ * of protection they want. If we support those, the messages will also have to
+ * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
+ *
+ * Since the ConnectionManager does asynchronous message passing, the SASL
+ * authentication is a bit more complex. A ConnectionManager can be both a client
+ * and a server, so for a particular connection it has to determine what to do.
+ * A ConnectionId was added to be able to track connections and is used to
+ * match up incoming messages with connections waiting for authentication.
+ * If it is acting as a client and trying to send a message to another ConnectionManager,
+ * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId
+ * and waits for the response from the server and does the handshake.
+ *
+ * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
+ * can be used. Yarn requires a specific AmIpFilter be installed for security to work
+ * properly. For non-Yarn deployments, users can write a filter to go through a
+ * company's normal login service. If an authentication filter is in place then the
+ * SparkUI can be configured to check the logged in user against the list of users who
+ * have view acls to see if that user is authorized.
+ * The filters can also be used for many different purposes. For instance filters
+ * could be used for logging, encryption, or compression.
+ *
+ * The exact mechanism used to generate/distribute the shared secret is deployment-specific.
+ *
+ * For Yarn deployments, the secret is automatically generated using the Akka remote
+ * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
+ * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels
+ * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn
+ * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn
+ * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there
+ * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use
+ * filters to do authentication. That authentication then happens via the ResourceManager Proxy
+ * and Spark will use that to do authorization against the view acls.
+ *
+ * For other Spark deployments, the shared secret must be specified via the
+ * spark.authenticate.secret config.
+ * All the nodes (Master and Workers) and the applications need to have the same shared secret.
+ * This again is not ideal, as one user could potentially affect another user's application.
+ * This should be enhanced in the future to provide better protection.
+ * If the UI needs to be secured the user needs to install a javax servlet filter to do the
+ * authentication. Spark will then use that user to compare against the view acls to do
+ * authorization. If no filter is in place the user is generally null and no authorization
+ * can take place.
+ */
+
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+
+ // key used to store the spark secret in the Hadoop UGI
+ private val sparkSecretLookupKey = "sparkCookie"
+
+ private val authOn = sparkConf.getBoolean("spark.authenticate", false)
+ private var uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false)
+
+ private var viewAcls: Set[String] = _
+ // always add the current user and SPARK_USER to the viewAcls
+ private val defaultAclUsers = Seq[String](System.getProperty("user.name", ""),
+ Option(System.getenv("SPARK_USER")).getOrElse(""))
+ setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", ""))
+
+ private val secretKey = generateSecretKey()
+ logInfo("SecurityManager, is authentication enabled: " + authOn +
+ " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString())
+
+ // Set our own authenticator to properly negotiate user/password for HTTP connections.
+ // This is needed by the HTTP client fetching from the HttpServer. Put here so it is
+ // only set once.
+ if (authOn) {
+ Authenticator.setDefault(
+ new Authenticator() {
+ override def getPasswordAuthentication(): PasswordAuthentication = {
+ var passAuth: PasswordAuthentication = null
+ val userInfo = getRequestingURL().getUserInfo()
+ if (userInfo != null) {
+ val parts = userInfo.split(":", 2)
+ passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray())
+ }
+ return passAuth
+ }
+ }
+ )
+ }
+
+ private[spark] def setViewAcls(defaultUsers: Seq[String], allowedUsers: String) {
+ viewAcls = (defaultUsers ++ allowedUsers.split(',')).map(_.trim()).filter(!_.isEmpty).toSet
+ logInfo("Changing view acls to: " + viewAcls.mkString(","))
+ }
+
+ private[spark] def setViewAcls(defaultUser: String, allowedUsers: String) {
+ setViewAcls(Seq[String](defaultUser), allowedUsers)
+ }
+
+ private[spark] def setUIAcls(aclSetting: Boolean) {
+ uiAclsOn = aclSetting
+ logInfo("Changing acls enabled to: " + uiAclsOn)
+ }
+
+ /**
+ * Generates or looks up the secret key.
+ *
+ * The way the key is stored depends on the Spark deployment mode. Yarn
+ * uses the Hadoop UGI.
+ *
+ * For non-Yarn deployments, if the config variable is not set,
+ * we throw an exception.
+ */
+ private def generateSecretKey(): String = {
+ if (!isAuthenticationEnabled) return null
+ // first check to see if the secret is already set, else generate a new one if on yarn
+ val sCookie = if (SparkHadoopUtil.get.isYarnMode) {
+ val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey)
+ if (secretKey != null) {
+ logDebug("in yarn mode, getting secret from credentials")
+ return new Text(secretKey).toString
+ } else {
+ logDebug("getSecretKey: yarn mode, secret key from credentials is null")
+ }
+ val cookie = akka.util.Crypt.generateSecureCookie
+ // if we generated the secret then we must be the first, so let's set it so it
+ // gets used by everyone else
+ SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie)
+ logInfo("adding secret to credentials in yarn mode")
+ cookie
+ } else {
+ // user must have set spark.authenticate.secret config
+ sparkConf.getOption("spark.authenticate.secret") match {
+ case Some(value) => value
+ case None => throw new Exception("Error: a secret key must be specified via the " +
+ "spark.authenticate.secret config")
+ }
+ }
+ sCookie
+ }
+
+ /**
+ * Check to see if Acls for the UI are enabled
+ * @return true if UI authentication is enabled, otherwise false
+ */
+ def uiAclsEnabled(): Boolean = uiAclsOn
+
+ /**
+ * Checks the given user against the view acl list to see if they have
+ * authorization to view the UI. If UI acls are disabled
+ * via spark.ui.acls.enable, all users have view access.
+ *
+ * @param user the user to check for view authorization
+ * @return true if the user has permission, otherwise false
+ */
+ def checkUIViewPermissions(user: String): Boolean = {
+ logDebug("user=" + user + " uiAclsEnabled=" + uiAclsEnabled() + " viewAcls=" +
+ viewAcls.mkString(","))
+ if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true
+ }
+
+ /**
+ * Check to see if authentication for the Spark communication protocols is enabled
+ * @return true if authentication is enabled, otherwise false
+ */
+ def isAuthenticationEnabled(): Boolean = authOn
+
+ /**
+ * Gets the user used for authenticating HTTP connections.
+ * For now use a single hardcoded user.
+ * @return the HTTP user as a String
+ */
+ def getHttpUser(): String = "sparkHttpUser"
+
+ /**
+ * Gets the user used for authenticating SASL connections.
+ * For now use a single hardcoded user.
+ * @return the SASL user as a String
+ */
+ def getSaslUser(): String = "sparkSaslUser"
+
+ /**
+ * Gets the secret key.
+ * @return the secret key as a String if authentication is enabled, otherwise returns null
+ */
+ def getSecretKey(): String = secretKey
+}
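A sketch of how the view-acl pieces above fit together. Since `SecurityManager` is `private[spark]`, the sketch must live in the `org.apache.spark` package, and authentication is left off so no secret is required:

```scala
package org.apache.spark

object SecurityManagerSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf(false)
      .set("spark.ui.acls.enable", "true")
      .set("spark.ui.view.acls", "alice,bob")
    val securityMgr = new SecurityManager(conf)

    // alice is listed in the acls (the launching user is always added too);
    // carol is not, so she is denied once acls are enabled.
    assert(securityMgr.checkUIViewPermissions("alice"))
    assert(!securityMgr.checkUIViewPermissions("carol"))
  }
}
```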
diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
index dff665cae6cb6..e50b9ac2291f9 100644
--- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala
+++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
@@ -23,6 +23,9 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.Writable
+import org.apache.spark.annotation.DeveloperApi
+
+@DeveloperApi
class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {
def value = t
override def toString = t.toString
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index e8f756c408889..a4f69b6b22b2c 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -29,7 +29,7 @@ private[spark] abstract class ShuffleFetcher {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
- serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
+ serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index b947feb891ee6..bd21fdc5a18e4 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -208,6 +208,82 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
new SparkConf(false).setAll(settings)
}
+ /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
+ * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
+ private[spark] def validateSettings() {
+ if (settings.contains("spark.local.dir")) {
+ val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
+ "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
+ logWarning(msg)
+ }
+
+ val executorOptsKey = "spark.executor.extraJavaOptions"
+ val executorClasspathKey = "spark.executor.extraClassPath"
+ val driverOptsKey = "spark.driver.extraJavaOptions"
+ val driverClassPathKey = "spark.driver.extraClassPath"
+
+ // Validate spark.executor.extraJavaOptions
+ settings.get(executorOptsKey).map { javaOpts =>
+ if (javaOpts.contains("-Dspark")) {
+ val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts)'. " +
+ "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
+ throw new Exception(msg)
+ }
+ if (javaOpts.contains("-Xmx") || javaOpts.contains("-Xms")) {
+ val msg = s"$executorOptsKey is not allowed to alter memory settings (was '$javaOpts'). " +
+ "Use spark.executor.memory instead."
+ throw new Exception(msg)
+ }
+ }
+
+ // Check for legacy configs
+ sys.env.get("SPARK_JAVA_OPTS").foreach { value =>
+ val error =
+ s"""
+ |SPARK_JAVA_OPTS was detected (set to '$value').
+ |This has undefined behavior when running on a cluster and is deprecated in Spark 1.0+.
+ |
+ |Please instead use:
+ | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application
+ | - ./spark-submit with --driver-java-options to set -X options for a driver
+ | - spark.executor.extraJavaOptions to set -X options for executors
+ | - SPARK_DAEMON_OPTS to set java options for standalone daemons (i.e. master, worker)
+ """.stripMargin
+ logError(error)
+
+ for (key <- Seq(executorOptsKey, driverOptsKey)) {
+ if (getOption(key).isDefined) {
+ throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.")
+ } else {
+ logWarning(s"Setting '$key' to '$value' as a work-around.")
+ set(key, value)
+ }
+ }
+ }
+
+ sys.env.get("SPARK_CLASSPATH").foreach { value =>
+ val error =
+ s"""
+ |SPARK_CLASSPATH was detected (set to '$value').
+ |This has undefined behavior when running on a cluster and is deprecated in Spark 1.0+.
+ |
+ |Please instead use:
+ | - ./spark-submit with --driver-class-path to augment the driver classpath
+ | - spark.executor.extraClassPath to augment the executor classpath
+ """.stripMargin
+ logError(error)
+
+ for (key <- Seq(executorClasspathKey, driverClassPathKey)) {
+ if (getOption(key).isDefined) {
+ throw new SparkException(s"Found both $key and SPARK_CLASSPATH. Use only the former.")
+ } else {
+ logWarning(s"Setting '$key' to '$value' as a work-around.")
+ set(key, value)
+ }
+ }
+ }
+ }
+
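Since `validateSettings()` is `private[spark]` and called from the `SparkContext` constructor (see below), the check is exercised simply by constructing a context. A sketch of the executor-memory rule:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local")
  .setAppName("validate-sketch")
  // Disallowed: heap size must be set via spark.executor.memory instead.
  .set("spark.executor.extraJavaOptions", "-Xmx2g")

// The SparkContext constructor calls conf.validateSettings(), so this would
// throw with the "not allowed to alter memory settings" message above:
// val sc = new SparkContext(conf)
```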
/**
* Return a string listing all keys and values, one per line. This is useful to print the
* configuration out for debugging.
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index a24f07e9a6e9a..eb14d87467af7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -19,14 +19,14 @@ package org.apache.spark
import java.io._
import java.net.URI
-import java.util.{Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-
+import java.util.{Properties, UUID}
+import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.language.implicitConversions
import scala.reflect.{ClassTag, classTag}
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
@@ -35,7 +35,10 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import org.apache.spark.input.WholeTextFileInputFormat
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -44,25 +47,38 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
/**
+ * :: DeveloperApi ::
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
* @param config a Spark Config object describing the application configuration. Any settings in
* this config overrides the default configs as well as system properties.
- * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. Can
- * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]]
- * from a list of input files or InputFormats for the application.
*/
-class SparkContext(
- config: SparkConf,
- // This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
- // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
- // contains a map from hostname to a list of input format splits on the host.
- val preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map())
- extends Logging {
+
+@DeveloperApi
+class SparkContext(config: SparkConf) extends Logging {
+
+ // This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
+ // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
+ // contains a map from hostname to a list of input format splits on the host.
+ private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()
+
+ /**
+ * :: DeveloperApi ::
+ * Alternative constructor for setting preferred locations where Spark will create executors.
+ *
+ * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. Can
+ * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]]
+ * from a list of input files or InputFormats for the application.
+ */
+ @DeveloperApi
+ def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = {
+ this(config)
+ this.preferredNodeLocationData = preferredNodeLocationData
+ }
/**
* Alternative constructor that allows setting common Spark properties directly
@@ -92,11 +108,47 @@ class SparkContext(
environment: Map[String, String] = Map(),
preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) =
{
- this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment),
- preferredNodeLocationData)
+ this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment))
+ this.preferredNodeLocationData = preferredNodeLocationData
}
+ // NOTE: The below constructors could be consolidated using default arguments. Due to
+ // Scala bug SI-8479, however, this causes the compile step to fail when generating docs.
+ // Until we have a good workaround for that bug the constructors remain broken out.
+
+ /**
+ * Alternative constructor that allows setting common Spark properties directly
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI.
+ */
+ private[spark] def this(master: String, appName: String) =
+ this(master, appName, null, Nil, Map(), Map())
+
+ /**
+ * Alternative constructor that allows setting common Spark properties directly
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI.
+ * @param sparkHome Location where Spark is installed on cluster nodes.
+ */
+ private[spark] def this(master: String, appName: String, sparkHome: String) =
+ this(master, appName, sparkHome, Nil, Map(), Map())
+
+ /**
+ * Alternative constructor that allows setting common Spark properties directly
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI.
+ * @param sparkHome Location where Spark is installed on cluster nodes.
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ */
+ private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) =
+ this(master, appName, sparkHome, jars, Map(), Map())
+
private[spark] val conf = config.clone()
+ conf.validateSettings()
/**
* Return a copy of this SparkContext's configuration. The configuration ''cannot'' be
@@ -108,7 +160,7 @@ class SparkContext(
throw new SparkException("A master URL must be set in your configuration")
}
if (!conf.contains("spark.app.name")) {
- throw new SparkException("An application must be set in your configuration")
+ throw new SparkException("An application name must be set in your configuration")
}
if (conf.getBoolean("spark.logConf", false)) {
@@ -119,17 +171,27 @@ class SparkContext(
conf.setIfMissing("spark.driver.host", Utils.localHostName())
conf.setIfMissing("spark.driver.port", "0")
- val jars: Seq[String] = if (conf.contains("spark.jars")) {
- conf.get("spark.jars").split(",").filter(_.size != 0)
- } else {
- null
- }
+ val jars: Seq[String] =
+ conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten
+
+ val files: Seq[String] =
+ conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten
val master = conf.get("spark.master")
val appName = conf.get("spark.app.name")
+ // Generate a random name for a temp folder in Tachyon.
+ // The UUID suffix keeps the folder name unique across applications.
+ val tachyonFolderName = "spark-" + randomUUID.toString()
+ conf.set("spark.tachyonStore.folderName", tachyonFolderName)
+
val isLocal = (master == "local" || master.startsWith("local["))
+ if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+
+ // An asynchronous listener bus for Spark events
+ private[spark] val listenerBus = new LiveListenerBus
+
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.create(
conf,
@@ -137,7 +199,8 @@ class SparkContext(
conf.get("spark.driver.host"),
conf.get("spark.driver.port").toInt,
isDriver = true,
- isLocal = isLocal)
+ isLocal = isLocal,
+ listenerBus = listenerBus)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
@@ -145,14 +208,50 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
- private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
+ private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]]
private[spark] val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
- // Initialize the Spark UI
+ // Initialize the Spark UI, registering all associated listeners
private[spark] val ui = new SparkUI(this)
ui.bind()
+ /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
+ val hadoopConfiguration: Configuration = {
+ val env = SparkEnv.get
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration()
+ // Explicitly check for S3 environment variables
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
+ System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ }
+ // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
+ conf.getAll.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ hadoopConf.set(key.substring("spark.hadoop.".length), value)
+ }
+ }
+ val bufferSize = conf.get("spark.buffer.size", "65536")
+ hadoopConf.set("io.file.buffer.size", bufferSize)
+ hadoopConf
+ }
+
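A sketch of the `spark.hadoop.*` passthrough implemented above (the S3 block-size key is an ordinary Hadoop setting, chosen only for illustration):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local")
  .setAppName("hadoop-conf-sketch")
  // "spark.hadoop." is stripped; the remainder lands in hadoopConfiguration.
  .set("spark.hadoop.fs.s3n.block.size", "67108864")

val sc = new SparkContext(conf)
assert(sc.hadoopConfiguration.get("fs.s3n.block.size") == "67108864")
sc.stop()
```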
+ // Optionally log Spark events
+ private[spark] val eventLogger: Option[EventLoggingListener] = {
+ if (conf.getBoolean("spark.eventLog.enabled", false)) {
+ val logger = new EventLoggingListener(appName, conf, hadoopConfiguration)
+ logger.start()
+ listenerBus.addListener(logger)
+ Some(logger)
+ } else None
+ }
+
+ // At this point, all relevant SparkListeners have been registered, so begin releasing events
+ listenerBus.start()
+
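Enabling the optional event logger is purely config-driven; a sketch, assuming the standard `spark.eventLog.dir` key points at a writable directory:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local")
  .setAppName("event-log-sketch")
  .set("spark.eventLog.enabled", "true")          // registers EventLoggingListener
  .set("spark.eventLog.dir", "/tmp/spark-events") // assumed writable path

val sc = new SparkContext(conf)
sc.parallelize(1 to 100).count() // job events flow through the LiveListenerBus
sc.stop()                        // stop() also stops the event logger
```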
val startTime = System.currentTimeMillis()
// Add each JAR given through the constructor
@@ -160,32 +259,36 @@ class SparkContext(
jars.foreach(addJar)
}
+ if (files != null) {
+ files.foreach(addFile)
+ }
+
+ private def warnSparkMem(value: String): String = {
+ logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " +
+ "deprecated, please use spark.executor.memory instead.")
+ value
+ }
+
private[spark] val executorMemory = conf.getOption("spark.executor.memory")
- .orElse(Option(System.getenv("SPARK_MEM")))
+ .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY")))
+ .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem))
.map(Utils.memoryStringToMb)
.getOrElse(512)
- if (!conf.contains("spark.executor.memory") && sys.env.contains("SPARK_MEM")) {
- logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " +
- "deprecated, instead use spark.executor.memory")
- }
+ // Environment variables to pass to our executors.
+ // NOTE: This should only be used for test related settings.
+ private[spark] val testExecutorEnvs = HashMap[String, String]()
- // Environment variables to pass to our executors
- private[spark] val executorEnvs = HashMap[String, String]()
- // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
- for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS");
- value <- Option(System.getenv(key))) {
- executorEnvs(key) = value
- }
// Convert java options to env vars as a work around
// since we can't set env vars directly in sbt.
- for { (envKey, propKey) <- Seq(("SPARK_HOME", "spark.home"), ("SPARK_TESTING", "spark.testing"))
+ for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing"))
value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
- executorEnvs(envKey) = value
+ testExecutorEnvs(envKey) = value
}
- // Since memory can be set with a system property too, use that
- executorEnvs("SPARK_MEM") = executorMemory + "m"
- executorEnvs ++= conf.getExecutorEnv
+ // The Mesos scheduler backend relies on this environment variable to set executor memory.
+ // TODO: Set this only in the Mesos scheduler.
+ testExecutorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m"
+ testExecutorEnvs ++= conf.getExecutorEnv
// Set SPARK_USER for user who is running SparkContext.
val sparkUser = Option {
@@ -193,39 +296,33 @@ class SparkContext(
}.getOrElse {
SparkContext.SPARK_UNKNOWN_USER
}
- executorEnvs("SPARK_USER") = sparkUser
+ testExecutorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
- private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName)
- taskScheduler.start()
-
- @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
- dagScheduler.start()
+ private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master)
+ @volatile private[spark] var dagScheduler: DAGScheduler = _
+ try {
+ dagScheduler = new DAGScheduler(this)
+ } catch {
+ case e: Exception =>
+ throw new SparkException("DAGScheduler cannot be initialized due to %s".format(e.getMessage))
+ }
- ui.start()
+ // Start the TaskScheduler only after the DAGScheduler constructor has set its
+ // reference in the TaskScheduler
+ taskScheduler.start()
- /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
- val hadoopConfiguration = {
- val env = SparkEnv.get
- val hadoopConf = SparkHadoopUtil.get.newConfiguration()
- // Explicitly check for S3 environment variables
- if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
- System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
- hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
- hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
- hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
- hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
- }
- // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- conf.getAll.foreach { case (key, value) =>
- if (key.startsWith("spark.hadoop.")) {
- hadoopConf.set(key.substring("spark.hadoop.".length), value)
- }
+ private[spark] val cleaner: Option[ContextCleaner] = {
+ if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
+ Some(new ContextCleaner(this))
+ } else {
+ None
}
- val bufferSize = conf.get("spark.buffer.size", "65536")
- hadoopConf.set("io.file.buffer.size", bufferSize)
- hadoopConf
}
+ cleaner.foreach(_.start())
+
+ postEnvironmentUpdate()
+ postApplicationStart()
private[spark] var checkpointDir: Option[String] = None
@@ -240,6 +337,7 @@ class SparkContext(
localProperties.set(props)
}
+ @deprecated("Properties no longer need to be explicitly initialized.", "1.0.0")
def initLocalProperties() {
localProperties.set(new Properties())
}
@@ -290,16 +388,27 @@ class SparkContext(
* // In a separate thread:
* sc.cancelJobGroup("some_job_to_cancel")
* }}}
+ *
+ * If interruptOnCancel is set to true for the job group, then job cancellation will result
+ * in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure
+ * that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208,
+ * where HDFS may respond to Thread.interrupt() by marking nodes as dead.
*/
- def setJobGroup(groupId: String, description: String) {
+ def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ // Note: Specifying interruptOnCancel in setJobGroup (rather than cancelJobGroup) avoids
+ // changing several public APIs and allows Spark cancellations outside of the cancelJobGroup
+ // APIs to also take advantage of this property (e.g., internal job failures or canceling from
+ // JobProgressTab UI) on a per-job basis.
+ setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
}
/** Clear the current thread's job group ID and its description. */
def clearJobGroup() {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
+ setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null)
}
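A sketch of the new flag in use (`slowTransform` is a hypothetical user function, and `sc` an existing context):

```scala
// In the thread that submits the job:
sc.setJobGroup("nightly-etl", "nightly ETL run", interruptOnCancel = true)
val n = sc.parallelize(1 to 1000000).map(slowTransform).count()

// In a separate watchdog thread: because interruptOnCancel was set,
// cancellation also interrupts the executor task threads (subject to the
// HDFS-1208 caveat above).
sc.cancelJobGroup("nightly-etl")
```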
// Post init
@@ -308,7 +417,7 @@ class SparkContext(
private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
- def initDriverMetrics() {
+ private def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
SparkEnv.get.metricsSystem.registerSource(blockManagerSource)
}
@@ -339,9 +448,50 @@ class SparkContext(
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
- def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = {
+ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
- minSplits).map(pair => pair._2.toString)
+ minPartitions).map(pair => pair._2.toString)
+ }
+
+ /**
+ * Read a directory of text files from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI. Each file is read as a single record and returned in a
+ * key-value pair, where the key is the path of each file, the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do `val rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; large files are also allowed, but may cause poor performance.
+ *
+ * @param minPartitions A suggested minimum number of partitions for the input data.
+ */
+ def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
+ RDD[(String, String)] = {
+ val job = new NewHadoopJob(hadoopConfiguration)
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val updateConf = job.getConfiguration
+ new WholeTextFileRDD(
+ this,
+ classOf[WholeTextFileInputFormat],
+ classOf[String],
+ classOf[String],
+ updateConf,
+ minPartitions)
}
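A usage sketch of the new API (the input directory is hypothetical):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits for mapValues

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("wtf-sketch"))

// Each record is (path, full file contents), so per-file work is one map call.
val files = sc.wholeTextFiles("/tmp/input-dir", minPartitions = 4)
val lineCounts = files.mapValues(_.lines.size).collect()
lineCounts.foreach { case (file, n) => println(s"$file: $n lines") }
sc.stop()
```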
/**
@@ -350,10 +500,10 @@ class SparkContext(
* using the older MapReduce API (`org.apache.hadoop.mapred`).
*
* @param conf JobConf for setting up the dataset
- * @param inputFormatClass Class of the [[InputFormat]]
+ * @param inputFormatClass Class of the InputFormat
* @param keyClass Class of the keys
* @param valueClass Class of the values
- * @param minSplits Minimum number of Hadoop Splits to generate.
+ * @param minPartitions Minimum number of Hadoop Splits to generate.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
* record, directly caching the returned RDD will create many references to the same object.
@@ -365,11 +515,11 @@ class SparkContext(
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int = defaultMinSplits
+ minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
- new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
+ new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat
@@ -384,7 +534,7 @@ class SparkContext(
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int = defaultMinSplits
+ minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
@@ -396,7 +546,7 @@ class SparkContext(
inputFormatClass,
keyClass,
valueClass,
- minSplits)
+ minPartitions)
}
/**
@@ -404,7 +554,7 @@ class SparkContext(
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
- * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits)
+ * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions)
* }}}
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
@@ -413,13 +563,13 @@ class SparkContext(
* a `map` function.
*/
def hadoopFile[K, V, F <: InputFormat[K, V]]
- (path: String, minSplits: Int)
+ (path: String, minPartitions: Int)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
hadoopFile(path,
fm.runtimeClass.asInstanceOf[Class[F]],
km.runtimeClass.asInstanceOf[Class[K]],
vm.runtimeClass.asInstanceOf[Class[V]],
- minSplits)
+ minPartitions)
}
/**
@@ -437,7 +587,7 @@ class SparkContext(
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
- hadoopFile[K, V, F](path, defaultMinSplits)
+ hadoopFile[K, V, F](path, defaultMinPartitions)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]
@@ -498,10 +648,10 @@ class SparkContext(
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int
+ minPartitions: Int
): RDD[(K, V)] = {
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
- hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)
+ hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
@@ -513,7 +663,7 @@ class SparkContext(
* */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
): RDD[(K, V)] =
- sequenceFile(path, keyClass, valueClass, defaultMinSplits)
+ sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -537,7 +687,7 @@ class SparkContext(
* a `map` function.
*/
def sequenceFile[K, V]
- (path: String, minSplits: Int = defaultMinSplits)
+ (path: String, minPartitions: Int = defaultMinPartitions)
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
@@ -546,7 +696,7 @@ class SparkContext(
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
val writables = hadoopFile(path, format,
kc.writableClass(km).asInstanceOf[Class[Writable]],
- vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits)
+ vc.writableClass(vm).asInstanceOf[Class[Writable]], minPartitions)
writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) }
}
@@ -560,13 +710,12 @@ class SparkContext(
*/
def objectFile[T: ClassTag](
path: String,
- minSplits: Int = defaultMinSplits
+ minPartitions: Int = defaultMinPartitions
): RDD[T] = {
- sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
+ sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
}
-
protected[spark] def checkpointFile[T: ClassTag](
path: String
): RDD[T] = {
@@ -580,6 +729,9 @@ class SparkContext(
def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
+ /** Get an RDD that has no partitions or elements. */
+ def emptyRDD[T: ClassTag] = new EmptyRDD[T](this)
+
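`emptyRDD` is handy as a typed fold seed when unioning a possibly-empty list of RDDs; a sketch (`loadDailyLogs` is a hypothetical helper, `sc` an existing context):

```scala
import org.apache.spark.rdd.RDD

val parts: Seq[RDD[String]] = loadDailyLogs()
// The zero-partition seed makes the fold total even when `parts` is empty.
val all: RDD[String] = parts.foldLeft[RDD[String]](sc.emptyRDD[String])(_ union _)
```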
// Methods for creating shared variables
/**
@@ -605,7 +757,7 @@ class SparkContext(
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
- (initialValue: R) = {
+ (initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
}
@@ -615,7 +767,11 @@ class SparkContext(
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
+ def broadcast[T](value: T): Broadcast[T] = {
+ val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+ cleaner.foreach(_.registerBroadcastForCleanup(bc))
+ bc
+ }
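The method now returns a typed `Broadcast[T]` and registers it with the `ContextCleaner` created above; a small sketch:

```scala
import org.apache.spark.broadcast.Broadcast

val lookup: Broadcast[Map[String, Int]] = sc.broadcast(Map("a" -> 1, "b" -> 2))

val total = sc.parallelize(Seq("a", "b", "a"))
  .map(k => lookup.value.getOrElse(k, 0)) // tasks read via the handle
  .reduce(_ + _)                          // total == 4
```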
/**
* Add a file to be downloaded with this Spark job on every node.
@@ -633,15 +789,24 @@ class SparkContext(
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
- Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
+ postEnvironmentUpdate()
}
+ /**
+ * :: DeveloperApi ::
+ * Register a listener to receive up-calls from events that happen during execution.
+ */
+ @DeveloperApi
def addSparkListener(listener: SparkListener) {
- dagScheduler.addSparkListener(listener)
+ listenerBus.addListener(listener)
}
+ /** The version of Spark on which this application is running. */
+ def version = SparkContext.SPARK_VERSION
+
/**
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
@@ -666,10 +831,6 @@ class SparkContext(
*/
def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
- def getStageInfo: Map[Stage,StageInfo] = {
- dagScheduler.stageToInfos
- }
-
/**
* Return information about blocks stored in all of the slaves
*/
@@ -693,7 +854,7 @@ class SparkContext(
}
/**
- * Return current scheduling mode
+ * Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
taskScheduler.schedulingMode
@@ -703,6 +864,7 @@ class SparkContext(
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
+ @deprecated("adding files no longer creates local copies that need to be deleted", "1.0.0")
def clearFiles() {
addedFiles.clear()
}
@@ -717,6 +879,22 @@ class SparkContext(
dagScheduler.getPreferredLocs(rdd, partition)
}
+ /**
+ * Register an RDD to be persisted in memory and/or disk storage
+ */
+ private[spark] def persistRDD(rdd: RDD[_]) {
+ persistentRdds(rdd.id) = rdd
+ }
+
+ /**
+ * Unpersist an RDD from memory and/or disk storage
+ */
+ private[spark] def unpersistRDD(rddId: Int, blocking: Boolean = true) {
+ env.blockManager.master.removeRdd(rddId, blocking)
+ persistentRdds.remove(rddId)
+ listenerBus.post(SparkListenerUnpersistRDD(rddId))
+ }
+
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
@@ -735,9 +913,11 @@ class SparkContext(
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
- if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") {
- // In order for this to work in yarn standalone mode the user must specify the
- // --addjars option to the client to upload the file into the distributed cache
+ // yarn-standalone is deprecated, but still supported
+ if (SparkHadoopUtil.get.isYarnMode() &&
+ (master == "yarn-standalone" || master == "yarn-cluster")) {
+ // In order for this to work in yarn-cluster mode the user must specify the
+ // --addjars option to the client to upload the file into the distributed cache
// of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
try {
@@ -745,7 +925,7 @@ class SparkContext(
} catch {
case e: Exception => {
// For now just log an error but allow to go through so spark examples work.
- // The spark examples don't really need the jar distributed since its also
+ // The spark examples don't really need the jar distributed since it's also
// the app jar.
logError("Error adding jar (" + e + "), was the --addJars option used?")
null
@@ -766,18 +946,21 @@ class SparkContext(
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
}
+ postEnvironmentUpdate()
}
/**
* Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
* any new nodes.
*/
+ @deprecated("adding jars no longer creates local copies that need to be deleted", "1.0.0")
def clearJars() {
addedJars.clear()
}
/** Shut down the SparkContext. */
def stop() {
+ postApplicationEnd()
ui.stop()
// Do this only if not stopped already - best effort.
// Prevent NPE if stopped more than once.
@@ -785,16 +968,16 @@ class SparkContext(
dagScheduler = null
if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
+ cleaner.foreach(_.stop())
dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
- // Clean up locally linked files
- clearFiles()
- clearJars()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
ResultTask.clearCache()
+ listenerBus.stop()
+ eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
} else {
logInfo("SparkContext already stopped")
@@ -825,13 +1008,13 @@ class SparkContext(
setLocalProperty("externalCallSite", null)
}
+ /**
+ * Capture the current user callsite and return a formatted version for printing. If the user
+ * has overridden the call site, this will return the user's version.
+ */
private[spark] def getCallSite(): String = {
- val callSite = getLocalProperty("externalCallSite")
- if (callSite == null) {
- Utils.formatSparkCallSite
- } else {
- callSite
- }
+ val defaultCallSite = Utils.getCallSiteInfo
+ Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString)
}
/**
@@ -846,6 +1029,9 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
+ if (dagScheduler == null) {
+ throw new SparkException("SparkContext has been shutdown")
+ }
val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
@@ -923,8 +1109,10 @@ class SparkContext(
}
/**
+ * :: DeveloperApi ::
* Run a job that can return approximate results.
*/
+ @DeveloperApi
def runApproximateJob[T, U, R](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
@@ -940,8 +1128,10 @@ class SparkContext(
}
/**
+ * :: Experimental ::
* Submit a job for execution and return a FutureJob holding the result.
*/
+ @Experimental
def submitJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
@@ -975,6 +1165,16 @@ class SparkContext(
dagScheduler.cancelAllJobs()
}
+ /** Cancel a given job if it's scheduled or running */
+ private[spark] def cancelJob(jobId: Int) {
+ dagScheduler.cancelJob(jobId)
+ }
+
+ /** Cancel a given stage and all jobs associated with it */
+ private[spark] def cancelStage(stageId: Int) {
+ dagScheduler.cancelStage(stageId)
+ }
+
/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
@@ -1003,8 +1203,12 @@ class SparkContext(
def defaultParallelism: Int = taskScheduler.defaultParallelism
/** Default min number of partitions for Hadoop RDDs when not given by user */
+ @deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
+ /** Default min number of partitions for Hadoop RDDs when not given by user */
+ def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
+
private val nextShuffleId = new AtomicInteger(0)
private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()
@@ -1014,6 +1218,29 @@ class SparkContext(
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
+ /** Post the application start event */
+ private def postApplicationStart() {
+ listenerBus.post(SparkListenerApplicationStart(appName, startTime, sparkUser))
+ }
+
+ /** Post the application end event */
+ private def postApplicationEnd() {
+ listenerBus.post(SparkListenerApplicationEnd(System.currentTimeMillis))
+ }
+
+ /** Post the environment update event once the task scheduler is ready */
+ private def postEnvironmentUpdate() {
+ if (taskScheduler != null) {
+ val schedulingMode = getSchedulingMode.toString
+ val addedJarPaths = addedJars.keys.toSeq
+ val addedFilePaths = addedFiles.keys.toSeq
+ val environmentDetails =
+ SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths)
+ val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails)
+ listenerBus.post(environmentUpdate)
+ }
+ }
+
/** Called by MetadataCleaner to clean up the persistentRdds map periodically */
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
@@ -1024,12 +1251,16 @@ class SparkContext(
* The SparkContext object contains a number of implicit conversions and parameters for use with
* various Spark features.
*/
-object SparkContext {
+object SparkContext extends Logging {
+
+ private[spark] val SPARK_VERSION = "1.0.0"
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+ private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"
+
private[spark] val SPARK_UNKNOWN_USER = ""
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
@@ -1054,16 +1285,18 @@ object SparkContext {
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
- implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
+ implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
+ (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
new PairRDDFunctions(rdd)
+ }
implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
- rdd: RDD[(K, V)]) =
+ rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag](
+ implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](
rdd: RDD[(K, V)]) =
new OrderedRDDFunctions[K, V, (K, V)](rdd)
@@ -1097,27 +1330,33 @@ object SparkContext {
}
// Helper objects for converting common types to Writable
- private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = {
+ private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T)
+ : WritableConverter[T] = {
val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
}
- implicit def intWritableConverter() = simpleWritableConverter[Int, IntWritable](_.get)
+ implicit def intWritableConverter(): WritableConverter[Int] =
+ simpleWritableConverter[Int, IntWritable](_.get)
- implicit def longWritableConverter() = simpleWritableConverter[Long, LongWritable](_.get)
+ implicit def longWritableConverter(): WritableConverter[Long] =
+ simpleWritableConverter[Long, LongWritable](_.get)
- implicit def doubleWritableConverter() = simpleWritableConverter[Double, DoubleWritable](_.get)
+ implicit def doubleWritableConverter(): WritableConverter[Double] =
+ simpleWritableConverter[Double, DoubleWritable](_.get)
- implicit def floatWritableConverter() = simpleWritableConverter[Float, FloatWritable](_.get)
+ implicit def floatWritableConverter(): WritableConverter[Float] =
+ simpleWritableConverter[Float, FloatWritable](_.get)
- implicit def booleanWritableConverter() =
+ implicit def booleanWritableConverter(): WritableConverter[Boolean] =
simpleWritableConverter[Boolean, BooleanWritable](_.get)
- implicit def bytesWritableConverter() = {
+ implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
}
- implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
+ implicit def stringWritableConverter(): WritableConverter[String] =
+ simpleWritableConverter[String, Text](_.toString)
implicit def writableWritableConverter[T <: Writable]() =
new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
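With these converters given explicit return types, the implicit resolution for `sequenceFile[K, V]` is unchanged in behavior; a sketch of the call they enable (the path is hypothetical, `sc` an existing context):

```scala
import org.apache.spark.SparkContext._ // brings the converters into implicit scope

// Keys and values come back as plain Scala types rather than Writables.
val pairs = sc.sequenceFile[String, Int]("/tmp/seq-data")
```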
@@ -1126,19 +1365,19 @@ object SparkContext {
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to SparkContext.
*/
- def jarOfClass(cls: Class[_]): Seq[String] = {
+ def jarOfClass(cls: Class[_]): Option[String] = {
val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class")
if (uri != null) {
val uriStr = uri.toString
if (uriStr.startsWith("jar:file:")) {
// URI will be of the form "jar:file:/path/foo.jar!/package/cls.class",
// so pull out the /path/foo.jar
- List(uriStr.substring("jar:file:".length, uriStr.indexOf('!')))
+ Some(uriStr.substring("jar:file:".length, uriStr.indexOf('!')))
} else {
- Nil
+ None
}
} else {
- Nil
+ None
}
}
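Since `jarOfClass` now returns `Option[String]`, callers can fold the missing-jar case away instead of handling an empty `Seq`; a sketch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object JarSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster("local").setAppName("jar-sketch")
    // None simply skips setJars, e.g. when running from a class directory.
    SparkContext.jarOfClass(JarSketch.getClass).foreach(jar => conf.setJars(Seq(jar)))
    val sc = new SparkContext(conf)
    sc.stop()
  }
}
```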
@@ -1147,7 +1386,7 @@ object SparkContext {
* to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in
* your driver program.
*/
- def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
+ def jarOfObject(obj: AnyRef): Option[String] = jarOfClass(obj.getClass)
/**
* Creates a modified version of a SparkConf with the parameters that can be passed separately
@@ -1177,11 +1416,9 @@ object SparkContext {
}
/** Creates a task scheduler based on a given master URL. Extracted for testing. */
- private def createTaskScheduler(sc: SparkContext, master: String, appName: String)
- : TaskScheduler =
- {
- // Regular expression used for local[N] master format
- val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = {
+ // Regular expression used for local[N] and local[*] master formats
+ val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
@@ -1204,8 +1441,11 @@ object SparkContext {
scheduler
case LOCAL_N_REGEX(threads) =>
+ def localCpuCount = Runtime.getRuntime.availableProcessors()
+ // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
+ val threadCount = if (threads == "*") localCpuCount else threads.toInt
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
- val backend = new LocalBackend(scheduler, threads.toInt)
+ val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
scheduler
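
A sketch of what the widened regex accepts: local[*] resolves to the JVM's reported core count, while local[N] stays exact.

    val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r

    def threadsFor(master: String): Option[Int] = master match {
      case LOCAL_N_REGEX(threads) =>
        Some(if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt)
      case _ => None
    }

    threadsFor("local[4]")  // Some(4)
    threadsFor("local[*]")  // Some(<available cores>)
    threadsFor("local[x]")  // None: only digits or * are accepted
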
@@ -1218,7 +1458,7 @@ object SparkContext {
case SPARK_REGEX(sparkUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
- val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
scheduler
@@ -1235,14 +1475,18 @@ object SparkContext {
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val masterUrls = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
scheduler
- case "yarn-standalone" =>
+ case "yarn-standalone" | "yarn-cluster" =>
+ if (master == "yarn-standalone") {
+ logWarning(
+ "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.")
+ }
val scheduler = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
@@ -1291,9 +1535,9 @@ object SparkContext {
val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false)
val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, sc, url, appName)
+ new CoarseMesosSchedulerBackend(scheduler, sc, url)
} else {
- new MesosSchedulerBackend(scheduler, sc, url, appName)
+ new MesosSchedulerBackend(scheduler, sc, url)
}
scheduler.initialize(backend)
scheduler
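
Taken together, the master strings this rewritten createTaskScheduler dispatches on in this section (a summary sketch; the exact behavior lives in the cases above):

    //   "local[4]" / "local[*]"     fixed or CPU-count thread pool
    //   "local[4, 3]"               4 threads, up to 3 failures per task (tests)
    //   "spark://h1:7077,h2:7077"   standalone cluster; comma-separated masters
    //   "local-cluster[2, 1, 512]"  in-process cluster: 2 workers, 1 core, 512 MB each
    //   "yarn-cluster"              ("yarn-standalone" logs a deprecation warning)
    //   "mesos://host:5050"         Mesos; spark.mesos.coarse picks the backend
    val conf = new SparkConf().setMaster("local[*]").setAppName("demo")
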
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 7ac65828f670f..bea435ec34ce9 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -17,31 +17,39 @@
package org.apache.spark
+import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.concurrent.Await
+import scala.util.Properties
import akka.actor._
import com.google.common.collect.MapMaker
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.storage.{BlockManager, BlockManagerMaster, BlockManagerMasterActor}
import org.apache.spark.network.ConnectionManager
-import org.apache.spark.serializer.{Serializer, SerializerManager}
+import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
/**
+ * :: DeveloperApi ::
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
* Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
* objects needs to have the right SparkEnv set. You can get the current environment with
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
+ *
+ * NOTE: This is not intended for external use. This is exposed for Shark and may be made private
+ * in a future release.
*/
-class SparkEnv private[spark] (
+@DeveloperApi
+class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
- val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -50,6 +58,7 @@ class SparkEnv private[spark] (
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
+ val securityManager: SecurityManager,
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
@@ -78,7 +87,7 @@ class SparkEnv private[spark] (
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
// UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
- //actorSystem.awaitTermination()
+ // actorSystem.awaitTermination()
}
private[spark]
@@ -88,6 +97,14 @@ class SparkEnv private[spark] (
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
}
}
+
+ private[spark]
+ def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers(key).stop()
+ }
+ }
}
object SparkEnv extends Logging {
@@ -120,10 +137,18 @@ object SparkEnv extends Logging {
hostname: String,
port: Int,
isDriver: Boolean,
- isLocal: Boolean): SparkEnv = {
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus = null): SparkEnv = {
+
+ // Listener bus is only used on the driver
+ if (isDriver) {
+ assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
+ }
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port,
- conf = conf)
+ val securityManager = new SecurityManager(conf)
+
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
+ securityManager = securityManager)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
@@ -137,17 +162,22 @@ object SparkEnv extends Logging {
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
- Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
+ val cls = Class.forName(name, true, classLoader)
+ // First try with the constructor that takes SparkConf. If we can't find one,
+ // use a no-arg constructor instead.
+ try {
+ cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ } catch {
+ case _: NoSuchMethodException =>
+ cls.getConstructor().newInstance().asInstanceOf[T]
+ }
}
- val serializerManager = new SerializerManager
+ val serializer = instantiateClass[Serializer](
+ "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
- val serializer = serializerManager.setDefault(
- conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf)
-
- val closureSerializer = serializerManager.get(
- conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"),
- conf)
+ val closureSerializer = instantiateClass[Serializer](
+ "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
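
The instantiateClass helper above now tries a one-argument SparkConf constructor first and falls back to a no-arg one, so pluggable components can receive configuration without a separate init step. A hedged sketch of a serializer that benefits (MySerializer is an assumption, not part of this change):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    // Because this class declares a SparkConf constructor, the reflective
    // lookup above passes the conf in; a class with only a no-arg
    // constructor would still load through the fallback path.
    class MySerializer(conf: SparkConf) extends JavaSerializer(conf)

    // Selected the same way "spark.serializer" is resolved above:
    //   conf.set("spark.serializer", "com.example.MySerializer")
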
@@ -164,40 +194,42 @@ object SparkEnv extends Logging {
}
}
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster(conf)
+ } else {
+ new MapOutputTrackerWorker(conf)
+ }
+
+ // Have to assign trackerActor after initialization as MapOutputTrackerActor
+ // requires the MapOutputTracker itself
+ mapOutputTracker.trackerActor = registerOrLookup(
+ "MapOutputTracker",
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
+
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
- new BlockManagerMasterActor(isLocal, conf)), conf)
+ new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf)
+
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf)
+ serializer, conf, securityManager, mapOutputTracker)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isDriver, conf)
+ val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
val cacheManager = new CacheManager(blockManager)
- // Have to assign trackerActor after initialization as MapOutputTrackerActor
- // requires the MapOutputTracker itself
- val mapOutputTracker = if (isDriver) {
- new MapOutputTrackerMaster(conf)
- } else {
- new MapOutputTracker(conf)
- }
- mapOutputTracker.trackerActor = registerOrLookup(
- "MapOutputTracker",
- new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
-
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
- val httpFileServer = new HttpFileServer()
+ val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
val metricsSystem = if (isDriver) {
- MetricsSystem.createMetricsSystem("driver", conf)
+ MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
- MetricsSystem.createMetricsSystem("executor", conf)
+ MetricsSystem.createMetricsSystem("executor", conf, securityManager)
}
metricsSystem.start()
@@ -219,7 +251,6 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
- serializerManager,
serializer,
closureSerializer,
cacheManager,
@@ -228,9 +259,63 @@ object SparkEnv extends Logging {
broadcastManager,
blockManager,
connectionManager,
+ securityManager,
httpFileServer,
sparkFilesDir,
metricsSystem,
conf)
}
+
+ /**
+ * Return a map representation of JVM information, Spark properties, system properties, and
+ * class paths. Map keys define the category, and map values represent the corresponding
+ * attributes as a sequence of KV pairs. This is used mainly for SparkListenerEnvironmentUpdate.
+ */
+ private[spark]
+ def environmentDetails(
+ conf: SparkConf,
+ schedulingMode: String,
+ addedJars: Seq[String],
+ addedFiles: Seq[String]): Map[String, Seq[(String, String)]] = {
+
+ val jvmInformation = Seq(
+ ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)),
+ ("Java Home", Properties.javaHome),
+ ("Scala Version", Properties.versionString),
+ ("Scala Home", Properties.scalaHome)
+ ).sorted
+
+ // Spark properties
+ // Always include the scheduling mode, even when it is not explicitly configured (the SparkUI reads it)
+ val schedulerMode =
+ if (!conf.contains("spark.scheduler.mode")) {
+ Seq(("spark.scheduler.mode", schedulingMode))
+ } else {
+ Seq[(String, String)]()
+ }
+ val sparkProperties = (conf.getAll ++ schedulerMode).sorted
+
+ // System properties that are not java classpaths
+ val systemProperties = System.getProperties.iterator.toSeq
+ val otherProperties = systemProperties.filter { case (k, v) =>
+ k != "java.class.path" && !k.startsWith("spark.")
+ }.sorted
+
+ // Class paths including all added jars and files
+ val classPathProperty = systemProperties.find { case (k, v) =>
+ k == "java.class.path"
+ }.getOrElse(("", ""))
+ val classPathEntries = classPathProperty._2
+ .split(conf.get("path.separator", ":"))
+ .filterNot(e => e.isEmpty)
+ .map(e => (e, "System Classpath"))
+ val addedJarsAndFiles = (addedJars ++ addedFiles).map((_, "Added By User"))
+ val classPaths = (addedJarsAndFiles ++ classPathEntries).sorted
+
+ Map[String, Seq[(String, String)]](
+ "JVM Information" -> jvmInformation,
+ "Spark Properties" -> sparkProperties,
+ "System Properties" -> otherProperties,
+ "Classpath Entries" -> classPaths)
+ }
}
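
The shape of the map environmentDetails returns, as consumed by SparkListenerEnvironmentUpdate (values are illustrative):

    // Map(
    //   "JVM Information"   -> Seq(("Java Version", "1.7.0 (Oracle)"), ("Java Home", "/usr/java"), ...),
    //   "Spark Properties"  -> Seq(("spark.app.name", "demo"), ("spark.scheduler.mode", "FIFO"), ...),
    //   "System Properties" -> Seq(("os.name", "Linux"), ...),
    //   "Classpath Entries" -> Seq(("/opt/app.jar", "Added By User"), ("/opt/spark/conf", "System Classpath"), ...)
    // )

    // Callable only from Spark's own packages (it is private[spark]):
    val details = SparkEnv.environmentDetails(conf, "FIFO", addedJars = Nil, addedFiles = Nil)
    details("Spark Properties").foreach { case (k, v) => println(s"$k = $v") }
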
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
index d34e47e8cac22..4351ed74b67fc 100644
--- a/core/src/main/scala/org/apache/spark/SparkException.scala
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -20,5 +20,5 @@ package org.apache.spark
class SparkException(message: String, cause: Throwable)
extends Exception(message, cause) {
- def this(message: String) = this(message, null)
+ def this(message: String) = this(message, null)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index d404459a8eb7e..f6703986bdf11 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -15,28 +15,26 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapred
+package org.apache.spark
import java.io.IOException
import java.text.NumberFormat
import java.text.SimpleDateFormat
import java.util.Date
+import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
-import org.apache.spark.Logging
-import org.apache.spark.SerializableWritable
+import org.apache.spark.rdd.HadoopRDD
/**
- * Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public
- * because we need to access this class from the `spark` package to use some package-private Hadoop
- * functions, but this class should not be used directly by users.
+ * Internal helper class that saves an RDD using a Hadoop OutputFormat.
*
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
-private[apache]
+private[spark]
class SparkHadoopWriter(@transient jobConf: JobConf)
extends Logging
with SparkHadoopMapRedUtil
@@ -44,7 +42,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
-
+
private var jobID = 0
private var splitID = 0
private var attemptID = 0
@@ -59,23 +57,24 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
def preSetup() {
setIDs(0, 0, 0)
- setConfParams()
-
- val jCtxt = getJobContext()
+ HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value)
+
+ val jCtxt = getJobContext()
getOutputCommitter().setupJob(jCtxt)
}
def setup(jobid: Int, splitid: Int, attemptid: Int) {
setIDs(jobid, splitid, attemptid)
- setConfParams()
+ HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(now),
+ jobid, splitID, attemptID, conf.value)
}
def open() {
val numfmt = NumberFormat.getInstance()
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
-
+
val outputName = "part-" + numfmt.format(splitID)
val path = FileOutputFormat.getOutputPath(conf.value)
val fs: FileSystem = {
@@ -86,7 +85,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
}
- getOutputCommitter().setupTask(getTaskContext())
+ getOutputCommitter().setupTask(getTaskContext())
writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
@@ -104,18 +103,18 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
def commit() {
val taCtxt = getTaskContext()
- val cmtr = getOutputCommitter()
+ val cmtr = getOutputCommitter()
if (cmtr.needsTaskCommit(taCtxt)) {
try {
cmtr.commitTask(taCtxt)
logInfo (taID + ": Committed")
} catch {
- case e: IOException => {
+ case e: IOException => {
logError("Error committing the output of task: " + taID.value, e)
cmtr.abortTask(taCtxt)
throw e
}
- }
+ }
} else {
logWarning ("No need to commit output of task: " + taID.value)
}
@@ -145,7 +144,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
private def getJobContext(): JobContext = {
- if (jobContext == null) {
+ if (jobContext == null) {
jobContext = newJobContext(conf.value, jID.value)
}
jobContext
@@ -167,24 +166,16 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
taID = new SerializableWritable[TaskAttemptID](
new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
}
-
- private def setConfParams() {
- conf.value.set("mapred.job.id", jID.value.toString)
- conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
- conf.value.set("mapred.task.id", taID.value.toString)
- conf.value.setBoolean("mapred.task.is.map", true)
- conf.value.setInt("mapred.task.partition", splitID)
- }
}
-private[apache]
+private[spark]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
- val jobtrackerID = formatter.format(new Date())
+ val jobtrackerID = formatter.format(time)
new JobID(jobtrackerID, id)
}
-
+
def createPathFromString(path: String, conf: JobConf): Path = {
if (path == null) {
throw new IllegalArgumentException("Output path is null")
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
new file mode 100644
index 0000000000000..65003b6ac6a0a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.IOException
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.RealmChoiceCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslClient
+import javax.security.sasl.SaslException
+
+import scala.collection.JavaConversions.mapAsJavaMap
+
+/**
+ * Implements SASL Client logic for Spark
+ */
+private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Used to respond to the server's counterpart, SaslServer, with SASL tokens
+ * represented as byte arrays.
+ *
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
+ null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslClientCallbackHandler(securityMgr))
+
+ /**
+ * Used to initiate SASL handshake with server.
+ * @return response to challenge if needed
+ */
+ def firstToken(): Array[Byte] = {
+ synchronized {
+ val saslToken: Array[Byte] =
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ logDebug("has initial response")
+ saslClient.evaluateChallenge(new Array[Byte](0))
+ } else {
+ new Array[Byte](0)
+ }
+ saslToken
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true if the exchange has completed, false otherwise
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslClient != null) saslClient.isComplete() else false
+ }
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param saslTokenMessage contains server's SASL token
+ * @return client's response SASL token
+ */
+ def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose()
+ } catch {
+ case e: SaslException => // ignored
+ } finally {
+ saslClient = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with shared secrets.
+ */
+ private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
+ CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
+ private val secretKey = securityMgr.getSecretKey()
+ private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
+ if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8"))
+
+ /**
+ * Implementation used to respond to SASL request from the server.
+ *
+ * @param callbacks objects that indicate what credential information the
+ * server's SaslServer requires from the client.
+ */
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("in the sasl client callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL client callback: setting username: " + userName)
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL client callback: setting userPassword")
+ pc.setPassword(userPassword)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case cb: RealmChoiceCallback => {}
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
new file mode 100644
index 0000000000000..f6b0a9132aca4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.AuthorizeCallback
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslException
+import javax.security.sasl.SaslServer
+import scala.collection.JavaConversions.mapAsJavaMap
+import org.apache.commons.net.util.Base64
+
+/**
+ * Encapsulates SASL server logic
+ */
+private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Actual SASL work done by this object from javax.security.sasl.
+ */
+ private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
+ SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslDigestCallbackHandler(securityMgr))
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true if the exchange has completed, false otherwise
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslServer != null) saslServer.isComplete() else false
+ }
+ }
+
+ /**
+ * Used to respond to the client's SASL tokens.
+ * @param token the client's SASL token
+ * @return challenge or result to send back to the client.
+ */
+ def response(token: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose()
+ } catch {
+ case e: SaslException => // ignore
+ } finally {
+ saslServer = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * for SASL DIGEST-MD5 mechanism
+ */
+ private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
+ extends CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
+
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("In the sasl server callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL server callback: setting username")
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL server callback: setting userPassword")
+ val password: Array[Char] =
+ SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8"))
+ pc.setPassword(password)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case ac: AuthorizeCallback => {
+ val authid = ac.getAuthenticationID()
+ val authzid = ac.getAuthorizationID()
+ if (authid.equals(authzid)) {
+ logDebug("set auth to true")
+ ac.setAuthorized(true)
+ } else {
+ logDebug("set auth to false")
+ ac.setAuthorized(false)
+ }
+ if (ac.isAuthorized()) {
+ logDebug("sasl server is authorized")
+ ac.setAuthorizedID(authzid)
+ }
+ }
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
+ }
+ }
+ }
+}
+
+private[spark] object SparkSaslServer {
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ val SASL_DEFAULT_REALM = "default"
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ val DIGEST = "DIGEST-MD5"
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only, we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH -> "true")
+
+ /**
+ * Encode a byte[] identifier as a Base64-encoded string.
+ *
+ * @param identifier identifier to encode
+ * @return Base64-encoded string
+ */
+ def encodeIdentifier(identifier: Array[Byte]): String = {
+ new String(Base64.encodeBase64(identifier), "utf-8")
+ }
+
+ /**
+ * Encode a password as a Base64-encoded char array.
+ * @param password the password as a byte array.
+ * @return the password as a Base64-encoded char array.
+ */
+ def encodePassword(password: Array[Byte]): Array[Char] = {
+ new String(Base64.encodeBase64(password), "utf-8").toCharArray()
+ }
+}
+
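A minimal in-process sketch of the DIGEST-MD5 exchange between the two new classes (securityMgr is an assumed SecurityManager configured with a shared secret):

    val client = new SparkSaslClient(securityMgr)
    val server = new SparkSaslServer(securityMgr)

    // DIGEST-MD5 has no initial client response, so firstToken() is empty;
    // the server issues the first challenge and the two sides alternate
    // until both report completion.
    var token: Array[Byte] = client.firstToken()
    while (!client.isComplete() || !server.isComplete()) {
      token = server.response(token)       // server evaluates the client token
      if (!client.isComplete()) {
        token = client.saslResponse(token) // client answers the challenge
      }
    }
    client.dispose()
    server.dispose()
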
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index cae983ed4c652..dc012cc381346 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -19,15 +19,21 @@ package org.apache.spark
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+/**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ */
+@DeveloperApi
class TaskContext(
val stageId: Int,
val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
@volatile var interrupted: Boolean = false,
- private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
+ private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty
) extends Serializable {
@deprecated("use partitionId", "0.8.1")
@@ -46,6 +52,7 @@ class TaskContext(
}
def executeOnCompleteCallbacks() {
- onCompleteCallbacks.foreach{_()}
+ // Process complete callbacks in the reverse order of registration
+ onCompleteCallbacks.reverse.foreach{_()}
}
}
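
Completion callbacks now run in reverse registration order, so resources acquired first are released last, mirroring stack unwinding. A sketch:

    val context = new TaskContext(stageId = 0, partitionId = 0, attemptId = 0)
    context.addOnCompleteCallback(() => println("registered first, runs last"))
    context.addOnCompleteCallback(() => println("registered second, runs first"))
    context.executeOnCompleteCallbacks()
    // prints "registered second, runs first", then "registered first, runs last"
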
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 3fd6f5eb472f4..a3074916d13e7 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -17,29 +17,35 @@
package org.apache.spark
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
/**
+ * :: DeveloperApi ::
* Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
* tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures.
*/
-private[spark] sealed trait TaskEndReason
+@DeveloperApi
+sealed trait TaskEndReason
-private[spark] case object Success extends TaskEndReason
+@DeveloperApi
+case object Success extends TaskEndReason
-private[spark]
+@DeveloperApi
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
-private[spark] case class FetchFailed(
+@DeveloperApi
+case class FetchFailed(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
reduceId: Int)
extends TaskEndReason
-private[spark] case class ExceptionFailure(
+@DeveloperApi
+case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement],
@@ -47,22 +53,28 @@ private[spark] case class ExceptionFailure(
extends TaskEndReason
/**
+ * :: DeveloperApi ::
* The task finished successfully, but the result was lost from the executor's block manager before
* it was fetched.
*/
-private[spark] case object TaskResultLost extends TaskEndReason
+@DeveloperApi
+case object TaskResultLost extends TaskEndReason
-private[spark] case object TaskKilled extends TaskEndReason
+@DeveloperApi
+case object TaskKilled extends TaskEndReason
/**
+ * :: DeveloperApi ::
* The task failed because the executor that it was running on was lost. This may happen because
* the task crashed the JVM.
*/
-private[spark] case object ExecutorLostFailure extends TaskEndReason
+@DeveloperApi
+case object ExecutorLostFailure extends TaskEndReason
/**
+ * :: DeveloperApi ::
* We don't know why the task ended -- for example, because of a ClassNotFound exception when
* deserializing the task result.
*/
-private[spark] case object UnknownReason extends TaskEndReason
-
+@DeveloperApi
+case object UnknownReason extends TaskEndReason
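
With the hierarchy now public (and still sealed), listener and tooling code can match on end reasons exhaustively. A sketch:

    import org.apache.spark._

    def describe(reason: TaskEndReason): String = reason match {
      case Success             => "task succeeded"
      case Resubmitted         => "finished earlier but since lost"
      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
        s"failed fetching shuffle $shuffleId map output $mapId from $bmAddress"
      case ExceptionFailure(className, description, _, _) =>
        s"exception $className: $description"
      case TaskResultLost      => "result lost before it was fetched"
      case TaskKilled          => "killed intentionally"
      case ExecutorLostFailure => "executor lost (possibly a crashed JVM)"
      case UnknownReason       => "unknown"
    }
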
diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
new file mode 100644
index 0000000000000..cbd6b2866e4f9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Exception thrown when a task is intentionally killed.
+ */
+private[spark] class TaskKilledException extends RuntimeException
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
new file mode 100644
index 0000000000000..8ae02154823ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.net.{URI, URL}
+import java.util.jar.{JarEntry, JarOutputStream}
+
+import scala.collection.JavaConversions._
+
+import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
+import com.google.common.io.Files
+
+/**
+ * Utilities for tests. Included in main codebase since it's used by multiple
+ * projects.
+ *
+ * TODO: See if we can move this to the test codebase by specifying
+ * test dependencies between projects.
+ */
+private[spark] object TestUtils {
+
+ /**
+ * Create a jar that defines classes with the given names.
+ *
+ * Note: if this is used during class loader tests, class names should be unique
+ * in order to avoid interference between tests.
+ */
+ def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = {
+ val tempDir = Files.createTempDir()
+ val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value)
+ val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
+ createJar(files, jarFile)
+ }
+
+
+ /**
+ * Create a jar file that contains this set of files. All files will be located at the root
+ * of the jar.
+ */
+ def createJar(files: Seq[File], jarFile: File): URL = {
+ val jarFileStream = new FileOutputStream(jarFile)
+ val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest())
+
+ for (file <- files) {
+ val jarEntry = new JarEntry(file.getName)
+ jarStream.putNextEntry(jarEntry)
+
+ val in = new FileInputStream(file)
+ val buffer = new Array[Byte](10240)
+ // Copy the whole file into the jar entry; read returns -1 at EOF
+ var nRead = in.read(buffer, 0, buffer.length)
+ while (nRead != -1) {
+ jarStream.write(buffer, 0, nRead)
+ nRead = in.read(buffer, 0, buffer.length)
+ }
+ in.close()
+ }
+ jarStream.close()
+ jarFileStream.close()
+
+ jarFile.toURI.toURL
+ }
+
+ // Adapted from the JavaCompiler.java doc examples
+ private val SOURCE = JavaFileObject.Kind.SOURCE
+ private def createURI(name: String) = {
+ URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}")
+ }
+
+ private class JavaSourceFromString(val name: String, val code: String)
+ extends SimpleJavaFileObject(createURI(name), SOURCE) {
+ override def getCharContent(ignoreEncodingErrors: Boolean) = code
+ }
+
+ /** Creates a compiled class with the given name. Class file will be placed in destDir. */
+ def createCompiledClass(className: String, destDir: File, value: String = ""): File = {
+ val compiler = ToolProvider.getSystemJavaCompiler
+ val sourceFile = new JavaSourceFromString(className,
+ "public class " + className + " { @Override public String toString() { " +
+ "return \"" + value + "\";}}")
+
+ // Calling this outputs a class file in pwd. It's easier to just rename the file than
+ // build a custom FileManager that controls the output location.
+ compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call()
+
+ val fileName = className + ".class"
+ val result = new File(fileName)
+ assert(result.exists(), "Compiled file not found: " + result.getAbsolutePath())
+ val out = new File(destDir, fileName)
+
+ // renameTo cannot handle in and out files in different filesystems
+ // use google's Files.move instead
+ Files.move(result, out)
+
+ assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath())
+ out
+ }
+}
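
A usage sketch: build a jar of uniquely-named classes and load one through a child class loader (names are illustrative; createCompiledClass generates top-level classes, so unqualified names are used):

    import java.net.URLClassLoader

    val jarUrl = TestUtils.createJarWithClasses(Seq("UniqueTestClassA"), value = "hello")
    val loader = new URLClassLoader(Array(jarUrl), getClass.getClassLoader)
    val cls = Class.forName("UniqueTestClassA", true, loader)
    assert(cls.newInstance().toString == "hello")
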
diff --git a/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java b/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java
new file mode 100644
index 0000000000000..db7b25c727d34
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.annotation;
+
+import java.lang.annotation.*;
+
+/**
+ * A new component of Spark which may have unstable APIs.
+ *
+ * NOTE: If there exists a Scaladoc comment that immediately precedes this annotation, the first
+ * line of the comment must be ":: AlphaComponent ::" with no trailing blank line. This is because
+ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever
+ * comes first.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
+ ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
+public @interface AlphaComponent {}
diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java
new file mode 100644
index 0000000000000..0ecef6db0e039
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.annotation;
+
+import java.lang.annotation.*;
+
+/**
+ * A lower-level, unstable API intended for developers.
+ *
+ * Developer APIs might change or be removed in minor versions of Spark.
+ *
+ * NOTE: If there exists a Scaladoc comment that immediately precedes this annotation, the first
+ * line of the comment must be ":: DeveloperApi ::" with no trailing blank line. This is because
+ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever
+ * comes first.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
+ ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
+public @interface DeveloperApi {}
diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/scala/org/apache/spark/annotation/Experimental.java
new file mode 100644
index 0000000000000..ff8120291455f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/annotation/Experimental.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.annotation;
+
+import java.lang.annotation.*;
+
+/**
+ * An experimental user-facing API.
+ *
+ * Experimental APIs might change or be removed in minor versions of Spark, or be adopted as
+ * first-class Spark APIs.
+ *
+ * NOTE: If there exists a Scaladoc comment that immediately precedes this annotation, the first
+ * line of the comment must be ":: Experimental ::" with no trailing blank line. This is because
+ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever
+ * comes first.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
+ ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
+public @interface Experimental {}
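
How the annotations are meant to be applied, including the Scaladoc ordering convention each NOTE above describes (a sketch):

    import org.apache.spark.annotation.{DeveloperApi, Experimental}

    /**
     * :: DeveloperApi ::
     * The ":: DeveloperApi ::" line comes first, with no trailing blank line,
     * so Scaladoc renders both the comment and the annotation.
     */
    @DeveloperApi
    class MySchedulerHook {

      /** :: Experimental :: May change or be removed in a minor release. */
      @Experimental
      def tryNewCodePath(): Unit = {}
    }
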
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 071044463d980..a6123bd108c11 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -19,15 +19,18 @@ package org.apache.spark.api.java
import java.lang.{Double => JDouble}
+import scala.language.implicitConversions
import scala.reflect.ClassTag
import org.apache.spark.Partitioner
import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
+import org.apache.spark.util.Utils
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] {
@@ -83,7 +86,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[JDouble, java.lang.Boolean]): JavaDoubleRDD =
- fromRDD(srdd.filter(x => f(x).booleanValue()))
+ fromRDD(srdd.filter(x => f.call(x).booleanValue()))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
@@ -131,7 +134,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
/**
* Return a sampled subset of this RDD.
*/
- def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD =
+ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
+ sample(withReplacement, fraction, Utils.random.nextLong)
+
+ /**
+ * Return a sampled subset of this RDD.
+ */
+ def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD =
fromRDD(srdd.sample(withReplacement, fraction, seed))
/**
@@ -140,6 +149,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
*/
def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd))
+
// Double RDD functions
/** Add up the elements in this RDD. */
@@ -176,14 +193,26 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
def meanApprox(timeout: Long, confidence: JDouble): PartialResult[BoundedDouble] =
srdd.meanApprox(timeout, confidence)
- /** (Experimental) Approximate operation to return the mean within a timeout. */
+ /**
+ * :: Experimental ::
+ * Approximate operation to return the mean within a timeout.
+ */
+ @Experimental
def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout)
- /** (Experimental) Approximate operation to return the sum within a timeout. */
+ /**
+ * :: Experimental ::
+ * Approximate operation to return the sum within a timeout.
+ */
+ @Experimental
def sumApprox(timeout: Long, confidence: JDouble): PartialResult[BoundedDouble] =
srdd.sumApprox(timeout, confidence)
- /** (Experimental) Approximate operation to return the sum within a timeout. */
+ /**
+ * :: Experimental ::
+ * Approximate operation to return the sum within a timeout.
+ */
+ @Experimental
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
/**
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 3f672900cb90f..554c065358648 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -18,8 +18,10 @@
package org.apache.spark.api.java
import java.util.{Comparator, List => JList}
+import java.lang.{Iterable => JIterable}
import scala.collection.JavaConversions._
+import scala.language.implicitConversions
import scala.reflect.ClassTag
import com.google.common.base.Optional
@@ -31,11 +33,13 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext.rddToPairRDDFunctions
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
-import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
(implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V])
@@ -89,7 +93,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
- new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
+ new JavaPairRDD[K, V](rdd.filter(x => f.call(x).booleanValue()))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
@@ -116,7 +120,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/**
* Return a sampled subset of this RDD.
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
+ def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
+ sample(withReplacement, fraction, Utils.random.nextLong)
+
+ /**
+ * Return a sampled subset of this RDD.
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
/**
@@ -126,6 +136,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.union(other.rdd))
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.intersection(other.rdd))
+
+
// first() has to be overridden here so that the generated method has the signature
// 'public scala.Tuple2 first()'; if the trait's definition is used,
// then the method has the signature 'public java.lang.Object first()',
@@ -165,9 +185,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Simplified version of combineByKey that hash-partitions the output RDD.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
- mergeValue: JFunction2[C, V, C],
- mergeCombiners: JFunction2[C, C, C],
- numPartitions: Int): JavaPairRDD[K, C] =
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ numPartitions: Int): JavaPairRDD[K, C] =
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
/**
@@ -190,16 +210,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
/**
- * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * :: Experimental ::
+ * Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
+ @Experimental
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
/**
- * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * :: Experimental ::
+ * Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
+ @Experimental
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
@@ -240,14 +264,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Group the values for each key in the RDD into a single sequence. Allows controlling the
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
- def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] =
+ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(partitioner)))
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with into `numPartitions` partitions.
*/
- def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] =
+ def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions)))
/**
@@ -357,7 +381,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the existing partitioner/parallelism level.
*/
- def groupByKey(): JavaPairRDD[K, JList[V]] =
+ def groupByKey(): JavaPairRDD[K, JIterable[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey()))
/**
@@ -442,7 +466,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
- def fn = (x: V) => f.apply(x).asScala
+ def fn = (x: V) => f.call(x).asScala
implicit val ctag: ClassTag[U] = fakeClassTag
fromRDD(rdd.flatMapValues(fn))
}
@@ -452,7 +476,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
- : JavaPairRDD[K, (JList[V], JList[W])] =
+ : JavaPairRDD[K, (JIterable[V], JIterable[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner)))
/**
@@ -460,14 +484,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2],
- partitioner: Partitioner): JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
+ def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other)))
/**
@@ -475,7 +499,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
- : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
/**
@@ -483,7 +507,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int)
- : JavaPairRDD[K, (JList[V], JList[W])] =
+ : JavaPairRDD[K, (JIterable[V], JIterable[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions)))
/**
@@ -491,16 +515,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int)
- : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
/** Alias for cogroup. */
- def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
+ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
/** Alias for cogroup. */
def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
- : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
/**
@@ -511,49 +535,57 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F],
- conf: JobConf) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: JobConf) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F]) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F],
- codec: Class[_ <: CompressionCodec]) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ codec: Class[_ <: CompressionCodec]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F],
- conf: Configuration) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: Configuration) {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
+ /**
+ * Output the RDD to any Hadoop-supported storage system, using
+ * a Configuration object for that storage system.
+ */
+ def saveAsNewAPIHadoopDataset(conf: Configuration) {
+ rdd.saveAsNewAPIHadoopDataset(conf)
+ }
+
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F]) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
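The new saveAsNewAPIHadoopDataset takes its output format, key/value classes, and output path from the Configuration itself. A hedged sketch of wiring one up through a new-API Job; the path and `pairRDD` (a JavaPairRDD[Text, IntWritable]) are hypothetical:

{{{
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}

val job = new Job(new Configuration())
job.setOutputKeyClass(classOf[Text])
job.setOutputValueClass(classOf[IntWritable])
job.setOutputFormatClass(classOf[TextOutputFormat[Text, IntWritable]])
FileOutputFormat.setOutputPath(job, new Path("hdfs://example/output"))

pairRDD.saveAsNewAPIHadoopDataset(job.getConfiguration)
}}}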
@@ -601,10 +633,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* order of the keys).
*/
def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = {
- class KeyOrdering(val a: K) extends Ordered[K] {
- override def compare(b: K) = comp.compare(a, b)
- }
- implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+ implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering.
fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending))
}
@@ -615,10 +644,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* order of the keys).
*/
def sortByKey(comp: Comparator[K], ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = {
- class KeyOrdering(val a: K) extends Ordered[K] {
- override def compare(b: K) = comp.compare(a, b)
- }
- implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+ implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering.
fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending, numPartitions))
}
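The replaced KeyOrdering boilerplate wrapped each key by hand; the one-liner instead relies on scala.math's low-priority comparatorToOrdering implicit, which manufactures an Ordering[K] from any Comparator[K] in implicit scope. The mechanism in isolation, where Key is a throwaway type with no built-in Ordering:

{{{
import java.util.Comparator

case class Key(n: Int)

val cmp = new Comparator[Key] {
  override def compare(a: Key, b: Key): Int = b.n - a.n   // descending by n
}
implicit val comparator = cmp

// No Ordering[Key] exists, so implicit search falls back to
// Ordering.comparatorToOrdering(comparator):
println(Seq(Key(1), Key(3), Key(2)).sorted)   // List(Key(3), Key(2), Key(1))
}}}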
@@ -677,21 +703,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
object JavaPairRDD {
private[spark]
- def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Seq[T])]): RDD[(K, JList[T])] = {
- rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList)
+ def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Iterable[T])]): RDD[(K, JIterable[T])] = {
+ rddToPairRDDFunctions(rdd).mapValues(asJavaIterable)
}
private[spark]
def cogroupResultToJava[K: ClassTag, V, W](
- rdd: RDD[(K, (Seq[V], Seq[W]))]): RDD[(K, (JList[V], JList[W]))] = {
- rddToPairRDDFunctions(rdd).mapValues(x => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+ rdd: RDD[(K, (Iterable[V], Iterable[W]))]): RDD[(K, (JIterable[V], JIterable[W]))] = {
+ rddToPairRDDFunctions(rdd).mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2)))
}
private[spark]
def cogroupResult2ToJava[K: ClassTag, V, W1, W2](
- rdd: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))]): RDD[(K, (JList[V], JList[W1], JList[W2]))] = {
+ rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))])
+ : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2]))] = {
rddToPairRDDFunctions(rdd)
- .mapValues(x => (seqAsJavaList(x._1), seqAsJavaList(x._2), seqAsJavaList(x._3)))
+ .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3)))
}
def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
@@ -700,6 +727,15 @@ object JavaPairRDD {
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+ private[spark]
+ implicit def toScalaFunction2[T1, T2, R](fun: JFunction2[T1, T2, R]): Function2[T1, T2, R] = {
+ (x: T1, x1: T2) => fun.call(x, x1)
+ }
+
+ private[spark] implicit def toScalaFunction[T, R](fun: JFunction[T, R]): T => R = x => fun.call(x)
+
+ private[spark]
+ implicit def pairFunToScalaFun[A, B, C](x: PairFunction[A, B, C]): A => (B, C) = y => x.call(y)
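These private conversions are what let the Java-side function types flow into Scala RDD methods expecting plain functions. The pattern, reduced to a self-contained sketch with a stand-in trait (JFun is not the real interface):

{{{
import scala.language.implicitConversions

// Stand-in for the Java-side single-method function interface.
trait JFun[T, R] extends java.io.Serializable {
  def call(t: T): R
}

implicit def toScalaFunction[T, R](fun: JFun[T, R]): T => R = x => fun.call(x)

val plusOne = new JFun[Int, Int] { override def call(t: Int): Int = t + 1 }
println(List(1, 2, 3).map(plusOne))   // List(2, 3, 4) -- conversion applied
}}}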
/** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 0055c98844ded..dc698dea75e43 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,12 +17,14 @@
package org.apache.spark.api.java
+import scala.language.implicitConversions
import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
extends JavaRDDLike[T, JavaRDD[T]] {
@@ -70,7 +72,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
- wrapRDD(rdd.filter((x => f(x).booleanValue())))
+ wrapRDD(rdd.filter((x => f.call(x).booleanValue())))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
@@ -97,7 +99,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
/**
* Return a sampled subset of this RDD.
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
+ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
+ sample(withReplacement, fraction, Utils.random.nextLong)
+
+ /**
+ * Return a sampled subset of this RDD.
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
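With the seed widened to Long and a seedless overload added, both forms are available; `javaRDD` below is an assumed JavaRDD[Int]:

{{{
val sampled      = javaRDD.sample(false, 0.1)        // seed drawn from Utils.random
val reproducible = javaRDD.sample(false, 0.1, 42L)   // explicit Long seed
}}}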
/**
@@ -106,6 +114,15 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
*/
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
+
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd))
+
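A semantics sketch for the new intersection (assuming a SparkContext `sc`): both inputs are deduplicated, so each common element appears once, at the cost of a shuffle:

{{{
val a = sc.parallelize(Seq(1, 1, 2, 3))
val b = sc.parallelize(Seq(1, 3, 3, 4))
println(a.intersection(b).collect().sorted.toList)   // List(1, 3)
}}}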
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 24a9925dbd22c..574a98636a619 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -17,9 +17,9 @@
package org.apache.spark.api.java
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, List => JList, Iterator => JIterator}
+import java.lang.{Iterable => JIterable}
-import scala.Tuple2
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -27,12 +27,14 @@ import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
@@ -67,14 +69,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[R](f: JFunction[T, R]): JavaRDD[R] =
- new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
+ new JavaRDD(rdd.map(f)(fakeClassTag))(fakeClassTag)
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[R: ClassTag](
- f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]],
+ f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] =
new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
preservesPartitioning))
@@ -82,15 +84,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
- new JavaDoubleRDD(rdd.map(x => f(x).doubleValue()))
+ def mapToDouble[R](f: DoubleFunction[T]): JavaDoubleRDD = {
+ new JavaDoubleRDD(rdd.map(x => f.call(x).doubleValue()))
+ }
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
- val ctag = implicitly[ClassTag[Tuple2[K2, V2]]]
- new JavaPairRDD(rdd.map(f)(ctag))(f.keyType(), f.valueType())
+ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
+ def cm = implicitly[ClassTag[(K2, V2)]]
+ new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
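Splitting map's overloads into mapToDouble/mapToPair (and likewise for flatMap and mapPartitions below) frees the Java function types from carrying ClassTag methods. The renamed entry points, exercised from Scala against an assumed JavaRDD[String] named `jrdd`:

{{{
import org.apache.spark.api.java.function.{DoubleFunction, PairFunction}

val lengths = jrdd.mapToDouble(new DoubleFunction[String] {
  override def call(s: String): java.lang.Double = s.length.toDouble
})
val pairs = jrdd.mapToPair(new PairFunction[String, String, Int] {
  override def call(s: String): (String, Int) = (s, s.length)
})
}}}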
/**
@@ -99,17 +102,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
import scala.collection.JavaConverters._
- def fn = (x: T) => f.apply(x).asScala
- JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType())
+ def fn = (x: T) => f.call(x).asScala
+ JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U])
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
+ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
import scala.collection.JavaConverters._
- def fn = (x: T) => f.apply(x).asScala
+ def fn = (x: T) => f.call(x).asScala
new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
}
@@ -117,19 +120,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
+ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
- def fn = (x: T) => f.apply(x).asScala
- val ctag = implicitly[ClassTag[Tuple2[K2, V2]]]
- JavaPairRDD.fromRDD(rdd.flatMap(fn)(ctag))(f.keyType(), f.valueType())
+ def fn = (x: T) => f.call(x).asScala
+ def cm = implicitly[ClassTag[(K2, V2)]]
+ JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U])
}
/**
@@ -137,52 +140,53 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U],
preservesPartitioning: Boolean): JavaRDD[U] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaRDD.fromRDD(rdd.mapPartitions(fn, preservesPartitioning)(f.elementType()))(f.elementType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaRDD.fromRDD(
+ rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U])
}
/**
- * Return a new RDD by applying a function to each partition of this RDD.
+ * Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
}
/**
- * Return a new RDD by applying a function to each partition of this RDD.
+ * Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
+ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
JavaPairRDD[K2, V2] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2])
}
-
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]],
- preservesPartitioning: Boolean): JavaDoubleRDD = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]],
+ preservesPartitioning: Boolean): JavaDoubleRDD = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning)
- .map((x: java.lang.Double) => x.doubleValue()))
+ .map(x => x.doubleValue()))
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2],
+ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2],
preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaPairRDD.fromRDD(rdd.mapPartitions(fn, preservesPartitioning))(f.keyType(), f.valueType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaPairRDD.fromRDD(
+ rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2])
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) {
- rdd.foreachPartition((x => f(asJavaIterator(x))))
+ rdd.foreachPartition((x => f.call(asJavaIterator(x))))
}
/**
@@ -202,20 +206,20 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
+ def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JIterable[T]] = {
implicit val ctagK: ClassTag[K] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(fakeClassTag)))
}
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
+ def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JIterable[T]] = {
implicit val ctagK: ClassTag[K] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[K])))
}
/**
@@ -255,9 +259,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
other: JavaRDDLike[U, _],
f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = {
def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
- f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
+ f.call(asJavaIterator(x), asJavaIterator(y)).iterator())
JavaRDD.fromRDD(
- rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType())
+ rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V])
}
// Actions (launch a job to return a value to the user program)
@@ -266,7 +270,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Applies a function f to all elements of this RDD.
*/
def foreach(f: VoidFunction[T]) {
- val cleanF = rdd.context.clean(f)
+ val cleanF = rdd.context.clean((x: T) => f.call(x))
rdd.foreach(cleanF)
}
@@ -279,9 +283,22 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
+ /**
+ * Return an iterator that contains all of the elements in this RDD.
+ *
+ * The iterator will consume as much memory as the largest partition in this RDD.
+ */
+ def toLocalIterator(): JIterator[T] = {
+ import scala.collection.JavaConversions._
+ rdd.toLocalIterator
+ }
+
+
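A usage sketch for toLocalIterator (`jrdd` assumed): unlike collect(), elements stream to the driver one partition at a time, so peak driver memory tracks the largest partition rather than the whole RDD:

{{{
val it = jrdd.toLocalIterator()
while (it.hasNext) {
  process(it.next())   // process() is a hypothetical driver-side handler
}
}}}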
/**
* Return an array that contains all of the elements in this RDD.
+ * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead
*/
+ @Deprecated
def toArray(): JList[T] = collect()
/**
@@ -320,7 +337,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
combOp: JFunction2[U, U, U]): U =
- rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
+ rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
/**
* Return the number of elements in the RDD.
@@ -328,16 +345,20 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def count(): Long = rdd.count()
/**
- * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * :: Experimental ::
+ * Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
+ @Experimental
def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
rdd.countApprox(timeout, confidence)
/**
- * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * :: Experimental ::
+ * Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
+ @Experimental
def countApprox(timeout: Long): PartialResult[BoundedDouble] =
rdd.countApprox(timeout)
@@ -374,7 +395,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
- def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
+ def takeSample(withReplacement: Boolean, num: Int): JList[T] =
+ takeSample(withReplacement, num, Utils.random.nextLong)
+
+ def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
new java.util.ArrayList(arr)
@@ -388,19 +412,24 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Save this RDD as a text file, using string representations of elements.
*/
- def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+ def saveAsTextFile(path: String): Unit = {
+ rdd.saveAsTextFile(path)
+ }
/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
- def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]): Unit = {
rdd.saveAsTextFile(path, codec)
+ }
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
- def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
+ def saveAsObjectFile(path: String): Unit = {
+ rdd.saveAsObjectFile(path)
+ }
/**
* Creates tuples of the elements in this RDD by applying `f`.
@@ -417,7 +446,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
- def checkpoint() = rdd.checkpoint()
+ def checkpoint(): Unit = {
+ rdd.checkpoint()
+ }
/**
* Return whether this RDD has been checkpointed or not
@@ -475,6 +506,26 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
+ /**
+ * Returns the maximum element from this RDD as defined by the specified
+ * Comparator[T].
+ * @param comp the comparator that defines ordering
+ * @return the maximum of the RDD
+ */
+ def max(comp: Comparator[T]): T = {
+ rdd.max()(Ordering.comparatorToOrdering(comp))
+ }
+
+ /**
+ * Returns the minimum element from this RDD as defined by the specified
+ * Comparator[T].
+ * @param comp the comparator that defines ordering
+ * @return the minimum of the RDD
+ */
+ def min(comp: Comparator[T]): T = {
+ rdd.min()(Ordering.comparatorToOrdering(comp))
+ }
+
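The new max/min take an explicit Comparator; since it is shipped to executors, it should be Serializable. A sketch against an assumed JavaRDD[String]:

{{{
import java.io.Serializable
import java.util.Comparator

class LengthComparator extends Comparator[String] with Serializable {
  override def compare(a: String, b: String): Int = a.length - b.length
}

val longest  = jrdd.max(new LengthComparator)   // longest string
val shortest = jrdd.min(new LengthComparator)
}}}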
/**
* Returns the first K elements from this RDD using the
* natural ordering for T while maintaining the order.
@@ -498,8 +549,4 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def name(): String = rdd.name
- /** Reset generator */
- def setGenerator(_generator: String) = {
- rdd.setGenerator(_generator)
- }
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index dc26b7f621fee..8b95cda511643 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -17,10 +17,12 @@
package org.apache.spark.api.java
+import java.util
import java.util.{Map => JMap}
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
+import scala.language.implicitConversions
import scala.reflect.ClassTag
import com.google.common.base.Optional
@@ -88,10 +90,36 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
*/
def this(master: String, appName: String, sparkHome: String, jars: Array[String],
environment: JMap[String, String]) =
- this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment))
+ this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment, Map()))
private[spark] val env = sc.env
+ def isLocal: java.lang.Boolean = sc.isLocal
+
+ def sparkUser: String = sc.sparkUser
+
+ def master: String = sc.master
+
+ def appName: String = sc.appName
+
+ def jars: util.List[String] = sc.jars
+
+ def startTime: java.lang.Long = sc.startTime
+
+ /** Default level of parallelism to use when not given by the user (e.g. for parallelize and makeRDD). */
+ def defaultParallelism: java.lang.Integer = sc.defaultParallelism
+
+ /**
+ * Default minimum number of partitions for Hadoop RDDs when not given by the user.
+ * @deprecated As of Spark 1.0.0, defaultMinSplits is deprecated, use
+ * {@link #defaultMinPartitions()} instead
+ */
+ @deprecated("use defaultMinPartitions", "1.0.0")
+ def defaultMinSplits: java.lang.Integer = sc.defaultMinSplits
+
+ /** Default minimum number of partitions for Hadoop RDDs when not given by the user. */
+ def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions
+
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
@@ -133,7 +161,48 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
- def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
+ def textFile(path: String, minPartitions: Int): JavaRDD[String] =
+ sc.textFile(path, minPartitions)
+
+ /**
+ * Read a directory of text files from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI. Each file is read as a single record and returned in a
+ * key-value pair, where the key is the path of each file and the value is its content.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; large files are also allowed, but they may degrade performance.
+ *
+ * @param minPartitions A suggested minimum number of partitions for the input data.
+ */
+ def wholeTextFiles(path: String, minPartitions: Int): JavaPairRDD[String, String] =
+ new JavaPairRDD(sc.wholeTextFiles(path, minPartitions))
+
+ /**
+ * Read a directory of text files from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI. Each file is read as a single record and returned in a
+ * key-value pair, where the key is the path of each file and the value is its content.
+ *
+ * @see `wholeTextFiles(path: String, minPartitions: Int)`.
+ */
+ def wholeTextFiles(path: String): JavaPairRDD[String, String] =
+ new JavaPairRDD(sc.wholeTextFiles(path))
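A usage sketch for wholeTextFiles (`jsc` is an assumed JavaSparkContext; the path reuses the hypothetical one from the scaladoc above):

{{{
val files = jsc.wholeTextFiles("hdfs://a-hdfs-path")
files.rdd.collect().foreach { case (path, content) =>
  println(s"$path -> ${content.length} chars")
}
}}}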
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
@@ -145,11 +214,11 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int
+ minPartitions: Int
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
- new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
+ new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minPartitions))
}
/** Get an RDD for a Hadoop SequenceFile.
@@ -173,9 +242,9 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
- def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
+ def objectFile[T](path: String, minPartitions: Int): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
- sc.objectFile(path, minSplits)(ctag)
+ sc.objectFile(path, minPartitions)(ctag)
}
/**
@@ -205,11 +274,11 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
inputFormatClass: Class[F],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int
+ minPartitions: Int
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
- new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
+ new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions))
}
/**
@@ -244,11 +313,11 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
inputFormatClass: Class[F],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int
+ minPartitions: Int
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
- new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
+ new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions))
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat
@@ -415,6 +484,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
* any new nodes.
*/
+ @deprecated("adding jars no longer creates local copies that need to be deleted", "1.0.0")
def clearJars() {
sc.clearJars()
}
@@ -423,6 +493,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
+ @deprecated("adding files no longer creates local copies that need to be deleted", "1.0.0")
def clearFiles() {
sc.clearFiles()
}
@@ -442,7 +513,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
sc.setCheckpointDir(dir)
}
- def getCheckpointDir = JavaUtils.optionToOptional(sc.getCheckpointDir)
+ def getCheckpointDir: Optional[String] = JavaUtils.optionToOptional(sc.getCheckpointDir)
protected def checkpointFile[T](path: String): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
@@ -499,6 +570,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* // In a separate thread:
* sc.cancelJobGroup("some_job_to_cancel");
* }}}
+ *
+ * If interruptOnCancel is set to true for the job group, then job cancellation will result
+ * in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure
+ * that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208,
+ * where HDFS may respond to Thread.interrupt() by marking nodes as dead.
+ */
+ def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean): Unit =
+ sc.setJobGroup(groupId, description, interruptOnCancel)
+
+ /**
+ * Assigns a group ID to all the jobs started by this thread until the group ID is set to a
+ * different value or cleared.
+ *
+ * @see `setJobGroup(groupId: String, description: String, interruptThread: Boolean)`.
+ * This method sets interruptOnCancel to false.
*/
def setJobGroup(groupId: String, description: String): Unit = sc.setJobGroup(groupId, description)
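A sketch tying the new overload to the cancellation pattern in the scaladoc (`jsc` assumed); passing true for interruptOnCancel additionally delivers Thread.interrupt() to running tasks, with the HDFS-1208 caveat noted above:

{{{
jsc.setJobGroup("nightly-report", "rebuild report tables", true)
// ... run jobs on this thread ...

// In a separate thread:
jsc.sc.cancelJobGroup("nightly-report")
}}}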
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index ecbf18849ad48..22810cb1c662d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
-object JavaUtils {
+private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
option match {
case Some(value) => Optional.of(value)
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala
deleted file mode 100644
index 7500a8943634b..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import java.lang.{Double => JDouble, Iterable => JIterable}
-
-/**
- * A function that returns zero or more records of type Double from each input record.
- */
-// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
-// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
-abstract class DoubleFlatMapFunction[T] extends WrappedFunction1[T, JIterable[JDouble]]
- with Serializable {
- // Intentionally left blank
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala
deleted file mode 100644
index 2cdf2e92c3daa..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import java.lang.{Double => JDouble}
-
-/**
- * A function that returns Doubles, and can be used to construct DoubleRDDs.
- */
-// DoubleFunction does not extend Function because some UDF functions, like map,
-// are overloaded for both Function and DoubleFunction.
-abstract class DoubleFunction[T] extends WrappedFunction1[T, JDouble] with Serializable {
- // Intentionally left blank
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
deleted file mode 100644
index bdb01f7670356..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ /dev/null
@@ -1,27 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.reflect.ClassTag
-
-/**
- * A function that returns zero or more output records from each input record.
- */
-abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
deleted file mode 100644
index aae1349c5e17c..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ /dev/null
@@ -1,27 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.reflect.ClassTag
-
-/**
- * A function that takes two inputs and returns zero or more output records.
- */
-abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.scala b/core/src/main/scala/org/apache/spark/api/java/function/Function.scala
deleted file mode 100644
index a5e1701f7718f..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.reflect.ClassTag
-import org.apache.spark.api.java.JavaSparkContext
-
-/**
- * Base class for functions whose return types do not create special RDDs. PairFunction and
- * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed
- * when mapping RDDs of other types.
- */
-abstract class Function[T, R] extends WrappedFunction1[T, R] with Serializable {
- def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag
-}
-
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.scala b/core/src/main/scala/org/apache/spark/api/java/function/Function2.scala
deleted file mode 100644
index fa3616cbcb4d2..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.reflect.ClassTag
-import org.apache.spark.api.java.JavaSparkContext
-
-/**
- * A two-argument function that takes arguments of type T1 and T2 and returns an R.
- */
-abstract class Function2[T1, T2, R] extends WrappedFunction2[T1, T2, R] with Serializable {
- def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag
-}
-
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.scala b/core/src/main/scala/org/apache/spark/api/java/function/Function3.scala
deleted file mode 100644
index 45152891e9272..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function3.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import org.apache.spark.api.java.JavaSparkContext
-import scala.reflect.ClassTag
-
-/**
- * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
- */
-abstract class Function3[T1, T2, T3, R] extends WrappedFunction3[T1, T2, T3, R] with Serializable {
- def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala
deleted file mode 100644
index 8467bbb892ab0..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import java.lang.{Iterable => JIterable}
-import org.apache.spark.api.java.JavaSparkContext
-import scala.reflect.ClassTag
-
-/**
- * A function that returns zero or more key-value pair records from each input record. The
- * key-value pairs are represented as scala.Tuple2 objects.
- */
-// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
-// overloaded for both FlatMapFunction and PairFlatMapFunction.
-abstract class PairFlatMapFunction[T, K, V] extends WrappedFunction1[T, JIterable[(K, V)]]
- with Serializable {
-
- def keyType(): ClassTag[K] = JavaSparkContext.fakeClassTag
-
- def valueType(): ClassTag[V] = JavaSparkContext.fakeClassTag
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala
deleted file mode 100644
index d0ba0b6307ee9..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.reflect.ClassTag
-import org.apache.spark.api.java.JavaSparkContext
-
-/**
- * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs.
- */
-// PairFunction does not extend Function because some UDF functions, like map,
-// are overloaded for both Function and PairFunction.
-abstract class PairFunction[T, K, V] extends WrappedFunction1[T, (K, V)] with Serializable {
-
- def keyType(): ClassTag[K] = JavaSparkContext.fakeClassTag
-
- def valueType(): ClassTag[V] = JavaSparkContext.fakeClassTag
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala
deleted file mode 100644
index ea94313a4ab59..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-/**
- * A function with no return value.
- */
-// This allows Java users to write void methods without having to return Unit.
-abstract class VoidFunction[T] extends Serializable {
- @throws(classOf[Exception])
- def call(t: T) : Unit
-}
-
-// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly
-// return Unit), so it is implicitly converted to a Function1[T, Unit]:
-object VoidFunction {
- implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x))
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala
deleted file mode 100644
index cfe694f65d558..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.runtime.AbstractFunction1
-
-/**
- * Subclass of Function1 for ease of calling from Java. The main thing it does is re-expose the
- * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
- * isn't marked to allow that).
- */
-private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
- @throws(classOf[Exception])
- def call(t: T): R
-
- final def apply(t: T): R = call(t)
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala
deleted file mode 100644
index eb9277c6fb4cb..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.runtime.AbstractFunction2
-
-/**
- * Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the
- * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
- * isn't marked to allow that).
- */
-private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
- @throws(classOf[Exception])
- def call(t1: T1, t2: T2): R
-
- final def apply(t1: T1, t2: T2): R = call(t1, t2)
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
deleted file mode 100644
index d314dbdf1d980..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.runtime.AbstractFunction3
-
-/**
- * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
- * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
- * isn't marked to allow that).
- */
-private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
- extends AbstractFunction3[T1, T2, T3, R] {
- @throws(classOf[Exception])
- def call(t1: T1, t2: T2, t3: T3): R
-
- final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
-}
-
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index e4d0285710e84..672c344a56597 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,10 +19,14 @@ package org.apache.spark.api.python
import java.io._
import java.net._
+import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import scala.util.Try
+
+import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.spark._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
@@ -86,20 +90,62 @@ private[spark] class PythonRDD[T: ClassTag](
dataOut.flush()
worker.shutdownOutput()
} catch {
+
case e: java.io.FileNotFoundException =>
readerException = e
- // Kill the Python worker process:
- worker.shutdownOutput()
+ Try(worker.shutdownOutput()) // kill Python worker process
+
case e: IOException =>
// This can happen for legitimate reasons if the Python code stops returning data
- // before we are done passing elements through, e.g., for take(). Just log a message
- // to say it happened.
- logInfo("stdin writer to Python finished early")
- logDebug("stdin writer to Python finished early", e)
+ // before we are done passing elements through, e.g., for take(). Just log a message to
+ // say it happened (as it could also be hiding a real IOException from a data source).
+ logInfo("stdin writer to Python finished early (may not be an error)", e)
+
+ case e: Exception =>
+ // We must avoid throwing exceptions here, because the thread uncaught exception handler
+ // will kill the whole executor (see Executor).
+ readerException = e
+ Try(worker.shutdownOutput()) // kill Python worker process
+ }
+ }
+ }.start()
+
+ // Necessary to distinguish between a task that has failed and a task that is finished
+ @volatile var complete: Boolean = false
+
+ // It is necessary to have a monitor thread for Python workers if the user cancels with
+ // interrupts disabled. In that case we need to kill the worker explicitly; otherwise the
+ // threads can block indefinitely.
+ new Thread(s"Worker Monitor for $pythonExec") {
+ override def run() {
+ // Kill the worker if it is interrupted or completed
+ // When a Python task completes, the context is always set to interrupted
+ while (!context.interrupted) {
+ Thread.sleep(2000)
+ }
+ if (!complete) {
+ try {
+ logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+ env.destroyPythonWorker(pythonExec, envVars.toMap)
+ } catch {
+ case e: Exception =>
+ logError("Exception when trying to kill worker", e)
+ }
}
}
}.start()
+ /*
+ * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
+ * other completion callbacks might invalidate the input. Because interruption
+ * is not synchronous this still leaves a potential race where the interruption is
+ * processed only after the stream becomes invalid.
+ */
+ context.addOnCompleteCallback{ () =>
+ complete = true // Indicate that the task has completed successfully
+ context.interrupted = true
+ }
+
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = new Iterator[Array[Byte]] {
@@ -141,7 +187,7 @@ private[spark] class PythonRDD[T: ClassTag](
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
- throw new PythonException(new String(obj))
+ throw new PythonException(new String(obj, "utf-8"), readerException)
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
// read some accumulator updates:
@@ -151,15 +197,17 @@ private[spark] class PythonRDD[T: ClassTag](
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)
-
}
Array.empty[Byte]
}
} catch {
- case eof: EOFException => {
+ case e: Exception if readerException != null =>
+ logError("Python worker exited unexpectedly (crashed)", e)
+ logError("Python crash may have been caused by prior exception:", readerException)
+ throw readerException
+
+ case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
- }
- case e: Throwable => throw e
}
}
@@ -174,7 +222,7 @@ private[spark] class PythonRDD[T: ClassTag](
}
/** Thrown for exceptions in user Python code. */
-private class PythonException(msg: String) extends Exception(msg)
+private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause)
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
@@ -198,6 +246,7 @@ private object SpecialLengths {
}
private[spark] object PythonRDD {
+ val UTF8 = Charset.forName("UTF-8")
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
@@ -258,7 +307,7 @@ private[spark] object PythonRDD {
}
def writeUTF(str: String, dataOut: DataOutputStream) {
- val bytes = str.getBytes("UTF-8")
+ val bytes = str.getBytes(UTF8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -274,11 +323,41 @@ private[spark] object PythonRDD {
file.close()
}
+ /**
+ * Convert an RDD of serialized Python dictionaries to Scala Maps
+ * TODO: Support more Python types.
+ */
+ def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ val unpickle = new Unpickler
+      // TODO: Figure out why flatMap is necessary for pyspark
+ iter.flatMap { row =>
+ unpickle.loads(row) match {
+ case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
+          // In case the partition doesn't have a collection
+ case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
+ }
+ }
+ }
+ }
+
+ /**
+   * Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by
+ * PySpark.
+ */
+ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter =>
+ val pickle = new Pickler
+ iter.map { row =>
+ pickle.dumps(row)
+ }
+ }
+ }
}
private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
- override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+ override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
}
/**
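
The pythonToJavaMap and javaToPython helpers above lean on the pyrolite library (net.razorvine.pickle) to move rows across the JVM/Python boundary. A minimal round-trip sketch of that serialization step, assuming pyrolite is on the classpath; the row contents are illustrative:

import net.razorvine.pickle.{Pickler, Unpickler}

object PickleRoundTrip {
  def main(args: Array[String]) {
    val pickler = new Pickler
    val unpickler = new Unpickler

    // Pickle one row, as javaToPython does for each element of the RDD
    val row = new java.util.HashMap[String, Any]()
    row.put("name", "spark")
    row.put("count", 42)
    val bytes: Array[Byte] = pickler.dumps(row)

    // Unpickle it back, as pythonToJavaMap does for each serialized element
    val restored = unpickler.loads(bytes)
    println(restored)  // a java.util.Map with the same entries
  }
}
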
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index d113d4040594d..738a3b1bed7f3 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -18,9 +18,8 @@
package org.apache.spark.broadcast
import java.io.Serializable
-import java.util.concurrent.atomic.AtomicLong
-import org.apache.spark._
+import org.apache.spark.SparkException
/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
@@ -29,7 +28,8 @@ import org.apache.spark._
* attempts to distribute broadcast variables using efficient broadcast algorithms to reduce
* communication cost.
*
- * Broadcast variables are created from a variable `v` by calling [[SparkContext#broadcast]].
+ * Broadcast variables are created from a variable `v` by calling
+ * [[org.apache.spark.SparkContext#broadcast]].
* The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the
* `value` method. The interpreter session below shows this:
*
@@ -51,48 +51,80 @@ import org.apache.spark._
* @tparam T Type of the data contained in the broadcast variable.
*/
abstract class Broadcast[T](val id: Long) extends Serializable {
- def value: T
- // We cannot have an abstract readObject here due to some weird issues with
- // readObject having to be 'private' in sub-classes.
+ /**
+ * Flag signifying whether the broadcast variable is valid
+ * (that is, not already destroyed) or not.
+ */
+ @volatile private var _isValid = true
- override def toString = "Broadcast(" + id + ")"
-}
-
-private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable {
-
- private var initialized = false
- private var broadcastFactory: BroadcastFactory = null
-
- initialize()
-
- // Called by SparkContext or Executor before using Broadcast
- private def initialize() {
- synchronized {
- if (!initialized) {
- val broadcastFactoryClass = conf.get(
- "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
-
- broadcastFactory =
- Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+ /** Get the broadcasted value. */
+ def value: T = {
+ assertValid()
+ getValue()
+ }
- // Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isDriver, conf)
+ /**
+ * Asynchronously delete cached copies of this broadcast on the executors.
+ * If the broadcast is used after this is called, it will need to be re-sent to each executor.
+ */
+ def unpersist() {
+ unpersist(blocking = false)
+ }
- initialized = true
- }
- }
+ /**
+ * Delete cached copies of this broadcast on the executors. If the broadcast is used after
+ * this is called, it will need to be re-sent to each executor.
+ * @param blocking Whether to block until unpersisting has completed
+ */
+ def unpersist(blocking: Boolean) {
+ assertValid()
+ doUnpersist(blocking)
}
- def stop() {
- broadcastFactory.stop()
+ /**
+ * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+ * once a broadcast variable has been destroyed, it cannot be used again.
+ */
+ private[spark] def destroy(blocking: Boolean) {
+ assertValid()
+ _isValid = false
+ doDestroy(blocking)
}
- private val nextBroadcastId = new AtomicLong(0)
+ /**
+ * Whether this Broadcast is actually usable. This should be false once persisted state is
+ * removed from the driver.
+ */
+ private[spark] def isValid: Boolean = {
+ _isValid
+ }
- def newBroadcast[T](value_ : T, isLocal: Boolean) =
- broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+ /**
+   * Actually get the broadcasted value. Concrete implementations of the Broadcast class must
+ * define their own way to get the value.
+ */
+ private[spark] def getValue(): T
+
+ /**
+ * Actually unpersist the broadcasted value on the executors. Concrete implementations of
+   * the Broadcast class must define their own logic to unpersist their own data.
+ */
+ private[spark] def doUnpersist(blocking: Boolean)
+
+ /**
+ * Actually destroy all data and metadata related to this broadcast variable.
+   * Implementations of the Broadcast class must define their own logic to destroy their own
+ * state.
+ */
+ private[spark] def doDestroy(blocking: Boolean)
+
+ /** Check if this broadcast is valid. If not valid, exception is thrown. */
+ private[spark] def assertValid() {
+ if (!_isValid) {
+ throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
+ }
+ }
- def isDriver = _isDriver
+ override def toString = "Broadcast(" + id + ")"
}
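
After this refactoring, Broadcast is a template: value, unpersist and destroy validate _isValid and then delegate to getValue(), doUnpersist() and doDestroy(). A sketch of the minimum a concrete subclass must supply; LocalOnlyBroadcast is hypothetical and keeps everything in-process purely for illustration:

package org.apache.spark.broadcast

// Hypothetical implementation, for illustration only. A real subclass would
// also cache the value in the BlockManager, as HttpBroadcast and
// TorrentBroadcast do.
private[spark] class LocalOnlyBroadcast[T](@transient var value_ : T, id: Long)
  extends Broadcast[T](id) {

  def getValue() = value_

  def doUnpersist(blocking: Boolean) {
    // Nothing is cached on executors in this toy implementation
  }

  def doDestroy(blocking: Boolean) {
    // destroy() has already flipped _isValid, so assertValid() rejects further use
    value_ = null.asInstanceOf[T]
  }
}
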
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 940e5ab805100..8c8ce9b1691ac 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -17,16 +17,21 @@
package org.apache.spark.broadcast
+import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
+import org.apache.spark.annotation.DeveloperApi
/**
- * An interface for all the broadcast implementations in Spark (to allow
+ * :: DeveloperApi ::
+ * An interface for all the broadcast implementations in Spark (to allow
* multiple broadcast implementations). SparkContext uses a user-specified
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
+@DeveloperApi
trait BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf): Unit
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
new file mode 100644
index 0000000000000..cf62aca4d45e8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark._
+
+private[spark] class BroadcastManager(
+ val isDriver: Boolean,
+ conf: SparkConf,
+ securityManager: SecurityManager)
+ extends Logging {
+
+ private var initialized = false
+ private var broadcastFactory: BroadcastFactory = null
+
+ initialize()
+
+ // Called by SparkContext or Executor before using Broadcast
+ private def initialize() {
+ synchronized {
+ if (!initialized) {
+ val broadcastFactoryClass =
+ conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+
+ broadcastFactory =
+ Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+
+ // Initialize appropriate BroadcastFactory and BroadcastObject
+ broadcastFactory.initialize(isDriver, conf, securityManager)
+
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ broadcastFactory.stop()
+ }
+
+ private val nextBroadcastId = new AtomicLong(0)
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean) = {
+ broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+ }
+
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
+ }
+}
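
Since BroadcastManager instantiates the factory reflectively from spark.broadcast.factory, switching broadcast mechanisms remains a pure configuration change. A usage sketch; the application code is illustrative:

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastFactoryDemo {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("broadcast-factory-demo")
      // Swap in the torrent implementation; HttpBroadcastFactory is the default
      .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

    val sc = new SparkContext(conf)
    val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))
    val sum = sc.parallelize(Seq("a", "b", "a")).map(lookup.value(_)).reduce(_ + _)
    println(sum)  // 4
    sc.stop()
  }
}
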
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 20207c261320b..29372f16f2cac 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -17,34 +17,64 @@
package org.apache.spark.broadcast
-import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.URL
+import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream}
+import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv}
+import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
+/**
+ * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses an HTTP server
+ * as the broadcast mechanism. The first time an HTTP broadcast variable (sent as part of a
+ * task) is deserialized in the executor, the broadcasted data is fetched from the driver
+ * (through a HTTP server running at the driver) and stored in the BlockManager of the
+ * executor to speed up future accesses.
+ */
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
- def value = value_
+ def getValue = value_
- def blockId = BroadcastBlockId(id)
+ val blockId = BroadcastBlockId(id)
+ /*
+ * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster
+   * does not need to be told about this block, as no other node needs to know about it.
+ */
HttpBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ SparkEnv.get.blockManager.putSingle(
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
if (!isLocal) {
HttpBroadcast.write(id, value_)
}
- // Called by JVM when deserializing an object
+ /**
+ * Remove all persisted state associated with this HTTP broadcast on the executors.
+ */
+ def doUnpersist(blocking: Boolean) {
+ HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
+ }
+
+ /**
+ * Remove all persisted state associated with this HTTP broadcast on the executors and driver.
+ */
+ def doDestroy(blocking: Boolean) {
+ HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
+ }
+
+ /** Used by the JVM when serializing this object. */
+ private def writeObject(out: ObjectOutputStream) {
+ assertValid()
+ out.defaultWriteObject()
+ }
+
+ /** Used by the JVM when deserializing this object. */
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
@@ -54,7 +84,13 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ /*
+ * We cache broadcast data in the BlockManager so that subsequent tasks using it
+ * do not need to re-fetch. This data is only used locally and no other node
+ * needs to fetch this block, so we don't notify the master.
+ */
+ SparkEnv.get.blockManager.putSingle(
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
@@ -63,40 +99,27 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
}
}
-/**
- * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
- */
-class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new HttpBroadcast[T](value_, isLocal, id)
-
- def stop() { HttpBroadcast.stop() }
-}
-
-private object HttpBroadcast extends Logging {
+private[spark] object HttpBroadcast extends Logging {
private var initialized = false
-
private var broadcastDir: File = null
private var compress: Boolean = false
private var bufferSize: Int = 65536
private var serverUri: String = null
private var server: HttpServer = null
+ private var securityManager: SecurityManager = null
// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
- private var cleaner: MetadataCleaner = null
-
private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
-
private var compressionCodec: CompressionCodec = null
+ private var cleaner: MetadataCleaner = null
- def initialize(isDriver: Boolean, conf: SparkConf) {
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
if (!initialized) {
bufferSize = conf.getInt("spark.buffer.size", 65536)
compress = conf.getBoolean("spark.broadcast.compress", true)
+ securityManager = securityMgr
if (isDriver) {
createServer(conf)
conf.set("spark.httpBroadcast.uri", serverUri)
@@ -126,19 +149,21 @@ private object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
- server = new HttpServer(broadcastDir)
+ server = new HttpServer(broadcastDir, securityManager)
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
}
+ def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
+
def write(id: Long, value: Any) {
- val file = new File(broadcastDir, BroadcastBlockId(id).name)
+ val file = getFile(id)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
} else {
- new FastBufferedOutputStream(new FileOutputStream(file), bufferSize)
+ new BufferedOutputStream(new FileOutputStream(file), bufferSize)
}
}
val ser = SparkEnv.get.serializer.newInstance()
@@ -149,15 +174,27 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
+ logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
+
+ var uc: URLConnection = null
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("broadcast security enabled")
+ val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
+ uc = newuri.toURL.openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("broadcast not using security")
+ uc = new URL(url).openConnection()
+ }
+
val in = {
- val httpConnection = new URL(url).openConnection()
- httpConnection.setReadTimeout(httpReadTimeout)
- val inputStream = httpConnection.getInputStream
+ uc.setReadTimeout(httpReadTimeout)
+ val inputStream = uc.getInputStream
if (compress) {
compressionCodec.compressedInputStream(inputStream)
} else {
- new FastBufferedInputStream(inputStream, bufferSize)
+ new BufferedInputStream(inputStream, bufferSize)
}
}
val ser = SparkEnv.get.serializer.newInstance()
@@ -167,20 +204,48 @@ private object HttpBroadcast extends Logging {
obj
}
- def cleanup(cleanupTime: Long) {
+ /**
+ * Remove all persisted blocks associated with this HTTP broadcast on the executors.
+ * If removeFromDriver is true, also remove these persisted blocks on the driver
+ * and delete the associated broadcast file.
+ */
+ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+ SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
+ if (removeFromDriver) {
+ val file = getFile(id)
+ files.remove(file.toString)
+ deleteBroadcastFile(file)
+ }
+ }
+
+ /**
+ * Periodically clean up old broadcasts by removing the associated map entries and
+ * deleting the associated files.
+ */
+ private def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
- try {
- iterator.remove()
- new File(file.toString).delete()
- logInfo("Deleted broadcast file '" + file + "'")
- } catch {
- case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
+ iterator.remove()
+ deleteBroadcastFile(new File(file.toString))
+ }
+ }
+ }
+
+ private def deleteBroadcastFile(file: File) {
+ try {
+ if (file.exists) {
+ if (file.delete()) {
+ logInfo("Deleted broadcast file: %s".format(file))
+ } else {
+ logWarning("Could not delete broadcast file: %s".format(file))
}
}
+ } catch {
+ case e: Exception =>
+ logError("Exception while deleting broadcast file: %s".format(file), e)
}
}
}
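
The executor-side read path above reduces to: open a URLConnection (rewritten for authentication when enabled), apply the five-minute read timeout, and wrap the stream in either a compression codec or a plain BufferedInputStream. A stripped-down sketch of that shape, eliding compression and the authenticated-URI branch; the URL is illustrative:

import java.io.{BufferedInputStream, ByteArrayOutputStream}
import java.net.URL

object HttpFetchSketch {
  // Mirrors the shape of HttpBroadcast.read: open, set a timeout, buffer, drain
  def fetch(url: String,
            bufferSize: Int = 65536,
            readTimeoutMs: Int = 5 * 60 * 1000): Array[Byte] = {
    val uc = new URL(url).openConnection()
    uc.setReadTimeout(readTimeoutMs)
    val in = new BufferedInputStream(uc.getInputStream, bufferSize)
    try {
      val out = new ByteArrayOutputStream()
      val buf = new Array[Byte](bufferSize)
      var n = in.read(buf)
      while (n != -1) {
        out.write(buf, 0, n)
        n = in.read(buf)
      }
      out.toByteArray
    } finally {
      in.close()
    }
  }
}
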
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
new file mode 100644
index 0000000000000..e3f6cdc6154dd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/**
+ * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses an
+ * HTTP server as the broadcast mechanism. Refer to
+ * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism.
+ */
+class HttpBroadcastFactory extends BroadcastFactory {
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ HttpBroadcast.initialize(isDriver, conf, securityMgr)
+ }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new HttpBroadcast[T](value_, isLocal, id)
+
+ def stop() { HttpBroadcast.stop() }
+
+ /**
+ * Remove all persisted state associated with the HTTP broadcast with the given ID.
+ * @param removeFromDriver Whether to remove state from the driver
+ * @param blocking Whether to block until unbroadcasted
+ */
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ HttpBroadcast.unpersist(id, removeFromDriver, blocking)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 22d783c8590c6..2659274c5e98e 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -17,24 +17,43 @@
package org.apache.spark.broadcast
-import java.io._
+import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
import scala.math
import scala.util.Random
-import org.apache.spark._
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils
+/**
+ * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
+ * protocol to do a distributed transfer of the broadcasted data to the executors.
+ * The mechanism is as follows. The driver serializes the broadcasted data,
+ * divides it into smaller chunks, and stores them in the BlockManager of the driver.
+ * These chunks are reported to the BlockManagerMaster so that all the executors can
+ * learn the location of those chunks. The first time the broadcast variable (sent as
+ * part of a task) is deserialized at an executor, all the chunks are fetched using
+ * the BlockManager. When all the chunks are fetched (initially from the driver's
+ * BlockManager), they are combined and deserialized to recreate the broadcasted data.
+ * However, the chunks are also stored in the BlockManager and reported to the
+ * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
+ * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
+ * made to other executors who already have those chunks, resulting in a distributed
+ * fetching. This prevents the driver from being the bottleneck in sending out multiple
+ * copies of the broadcast data (one per executor) as done by the
+ * [[org.apache.spark.broadcast.HttpBroadcast]].
+ */
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
+ extends Broadcast[T](id) with Logging with Serializable {
- def value = value_
+ def getValue = value_
- def broadcastId = BroadcastBlockId(id)
+ val broadcastId = BroadcastBlockId(id)
TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
@transient var arrayOfBlocks: Array[TorrentBlock] = null
@@ -46,32 +65,52 @@ extends Broadcast[T](id) with Logging with Serializable {
sendBroadcast()
}
- def sendBroadcast() {
- var tInfo = TorrentBroadcast.blockifyObject(value_)
+ /**
+ * Remove all persisted state associated with this Torrent broadcast on the executors.
+ */
+ def doUnpersist(blocking: Boolean) {
+ TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
+ }
+
+ /**
+ * Remove all persisted state associated with this Torrent broadcast on the executors
+ * and driver.
+ */
+ def doDestroy(blocking: Boolean) {
+ TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
+ }
+ def sendBroadcast() {
+ val tInfo = TorrentBroadcast.blockifyObject(value_)
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
hasBlocks = tInfo.totalBlocks
// Store meta-info
- val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaId = BroadcastBlockId(id, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
- metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
}
// Store individual pieces
for (i <- 0 until totalBlocks) {
- val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ val pieceId = BroadcastBlockId(id, "piece" + i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
- pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
}
}
}
- // Called by JVM when deserializing an object
+ /** Used by the JVM when serializing this object. */
+ private def writeObject(out: ObjectOutputStream) {
+ assertValid()
+ out.defaultWriteObject()
+ }
+
+ /** Used by the JVM when deserializing this object. */
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
@@ -86,18 +125,22 @@ extends Broadcast[T](id) with Logging with Serializable {
// Initialize @transient variables that will receive garbage values from the master.
resetWorkerVariables()
- if (receiveBroadcast(id)) {
+ if (receiveBroadcast()) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
- // This creates a tradeoff between memory usage and latency.
- // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ * This creates a trade-off between memory usage and latency. Storing copy doubles
+ * the memory footprint; not storing doubles deserialization cost. Also,
+ * this does not need to be reported to BlockManagerMaster since other executors
+       * do not need to access this block (they only need to fetch the chunks,
+ * which are reported).
+ */
SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
- } else {
+ } else {
logError("Reading broadcast variable " + id + " failed")
}
@@ -114,9 +157,10 @@ extends Broadcast[T](id) with Logging with Serializable {
hasBlocks = 0
}
- def receiveBroadcast(variableID: Long): Boolean = {
- // Receive meta-info
- val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ def receiveBroadcast(): Boolean = {
+ // Receive meta-info about the size of broadcast data,
+ // the number of chunks it is divided into, etc.
+ val metaId = BroadcastBlockId(id, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
@@ -138,17 +182,21 @@ extends Broadcast[T](id) with Logging with Serializable {
return false
}
- // Receive actual blocks
+ /*
+ * Fetch actual chunks of data. Note that all these chunks are stored in
+ * the BlockManager and reported to the master, so that other executors
+ * can find out and pull the chunks from this executor.
+ */
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
- val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ val pieceId = BroadcastBlockId(id, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
SparkEnv.get.blockManager.putSingle(
- pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
@@ -156,18 +204,18 @@ extends Broadcast[T](id) with Logging with Serializable {
}
}
- (hasBlocks == totalBlocks)
+ hasBlocks == totalBlocks
}
}
-private object TorrentBroadcast
-extends Logging {
-
+private[spark] object TorrentBroadcast extends Logging {
+ private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
+
def initialize(_isDriver: Boolean, conf: SparkConf) {
- TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests
+ TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
synchronized {
if (!initialized) {
initialized = true
@@ -179,39 +227,37 @@ extends Logging {
initialized = false
}
- lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
-
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
val bais = new ByteArrayInputStream(byteArray)
- var blockNum = (byteArray.length / BLOCK_SIZE)
+ var blockNum = byteArray.length / BLOCK_SIZE
if (byteArray.length % BLOCK_SIZE != 0) {
blockNum += 1
}
- var retVal = new Array[TorrentBlock](blockNum)
- var blockID = 0
+ val blocks = new Array[TorrentBlock](blockNum)
+ var blockId = 0
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+ val tempByteArray = new Array[Byte](thisBlockSize)
+ bais.read(tempByteArray, 0, thisBlockSize)
- retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
- blockID += 1
+ blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+ blockId += 1
}
bais.close()
- val tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
- tInfo.hasBlocks = blockNum
-
- tInfo
+ val info = TorrentInfo(blocks, blockNum, byteArray.length)
+ info.hasBlocks = blockNum
+ info
}
- def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
- totalBytes: Int,
- totalBlocks: Int): T = {
+ def unBlockifyObject[T](
+ arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
val retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
@@ -220,6 +266,13 @@ extends Logging {
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
}
+ /**
+ * Remove all persisted blocks associated with this torrent broadcast on the executors.
+ * If removeFromDriver is true, also remove these persisted blocks on the driver.
+ */
+ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+ SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
+ }
}
private[spark] case class TorrentBlock(
@@ -228,23 +281,10 @@ private[spark] case class TorrentBlock(
extends Serializable
private[spark] case class TorrentInfo(
- @transient arrayOfBlocks : Array[TorrentBlock],
+ @transient arrayOfBlocks: Array[TorrentBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}
-
-/**
- * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast.
- */
-class TorrentBroadcastFactory extends BroadcastFactory {
-
- def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new TorrentBroadcast[T](value_, isLocal, id)
-
- def stop() { TorrentBroadcast.stop() }
-}
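
Underneath the BlockManager plumbing, blockifyObject and unBlockifyObject are ceiling-division chunking plus System.arraycopy reassembly. A self-contained sketch of just that arithmetic, with the serialization and block-store plumbing stripped out:

object TorrentChunking {
  def blockify(bytes: Array[Byte], blockSize: Int): Array[Array[Byte]] = {
    // Same ceiling division as blockifyObject's blockNum computation
    val blockNum = (bytes.length + blockSize - 1) / blockSize
    Array.tabulate(blockNum) { i =>
      val offset = i * blockSize
      val size = math.min(blockSize, bytes.length - offset)
      java.util.Arrays.copyOfRange(bytes, offset, offset + size)
    }
  }

  def unblockify(blocks: Array[Array[Byte]], totalBytes: Int): Array[Byte] = {
    val out = new Array[Byte](totalBytes)
    var offset = 0
    for (block <- blocks) {
      System.arraycopy(block, 0, out, offset, block.length)
      offset += block.length
    }
    out
  }

  def main(args: Array[String]) {
    val data = Array.tabulate(10000)(i => (i % 128).toByte)
    val roundTrip = unblockify(blockify(data, 4096), data.length)
    assert(java.util.Arrays.equals(data, roundTrip))  // 3 blocks: 4096 + 4096 + 1808
  }
}
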
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
new file mode 100644
index 0000000000000..d216b58718148
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/**
+ * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a BitTorrent-like
+ * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to
+ * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
+ */
+class TorrentBroadcastFactory extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ TorrentBroadcast.initialize(isDriver, conf)
+ }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+
+ /**
+ * Remove all persisted state associated with the torrent broadcast with the given ID.
+ * @param removeFromDriver Whether to remove state from the driver.
+ * @param blocking Whether to block until unbroadcasted
+ */
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
index 449b953530ff9..86305d2ea8a09 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
@@ -23,7 +23,8 @@ private[spark] class ApplicationDescription(
val memoryPerSlave: Int,
val command: Command,
val sparkHome: Option[String],
- val appUiUrl: String)
+ var appUiUrl: String,
+ val eventLogDir: Option[String] = None)
extends Serializable {
val user = System.getProperty("user.name", "")
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index eb5676b51d836..7ead1171525d2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -26,7 +26,7 @@ import akka.pattern.ask
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -54,8 +54,21 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
System.getenv().foreach{case (k, v) => env(k) = v}
val mainClass = "org.apache.spark.deploy.worker.DriverWrapper"
+
+ val classPathConf = "spark.driver.extraClassPath"
+ val classPathEntries = sys.props.get(classPathConf).toSeq.flatMap { cp =>
+ cp.split(java.io.File.pathSeparator)
+ }
+
+ val libraryPathConf = "spark.driver.extraLibraryPath"
+ val libraryPathEntries = sys.props.get(libraryPathConf).toSeq.flatMap { cp =>
+ cp.split(java.io.File.pathSeparator)
+ }
+
+ val javaOptionsConf = "spark.driver.extraJavaOptions"
+ val javaOpts = sys.props.get(javaOptionsConf)
val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++
- driverArgs.driverOptions, env)
+ driverArgs.driverOptions, env, classPathEntries, libraryPathEntries, javaOpts)
val driverDescription = new DriverDescription(
driverArgs.jarUrl,
@@ -128,6 +141,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
*/
object Client {
def main(args: Array[String]) {
+ println("WARNING: This client is deprecated and will be removed in a future version of Spark.")
+ println("Use ./bin/spark-submit with \"--master spark://host:port\"")
+
val conf = new SparkConf()
val driverArgs = new ClientArguments(args)
@@ -141,7 +157,7 @@ object Client {
// TODO: See if we can initialize akka so return messages are sent back using the same TCP
// flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
- "driverClient", Utils.localHostName(), 0, false, conf)
+ "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 00f5cd54ad650..5da9615c9e9af 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -43,7 +43,7 @@ private[spark] class ClientArguments(args: Array[String]) {
// kill parameters
var driverId: String = ""
-
+
parse(args.toList)
def parse(args: List[String]): Unit = args match {
@@ -112,5 +112,5 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
- def isValidJarUrl(s: String) = s.matches("(.+):(.+)jar")
+ def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar")
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/Command.scala b/core/src/main/scala/org/apache/spark/deploy/Command.scala
index fa8af9a646750..32f3ba385084f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Command.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Command.scala
@@ -22,5 +22,8 @@ import scala.collection.Map
private[spark] case class Command(
mainClass: String,
arguments: Seq[String],
- environment: Map[String, String]) {
+ environment: Map[String, String],
+ classPathEntries: Seq[String],
+ libraryPathEntries: Seq[String],
+ extraJavaOptions: Option[String] = None) {
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 83ce14a0a806a..a7368f9f3dfbe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -86,6 +86,10 @@ private[deploy] object DeployMessages {
case class KillDriver(driverId: String) extends DeployMessage
+ // Worker internal
+
+ case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
+
// AppClient to Master
case class RegisterApplication(appDescription: ApplicationDescription)
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 190b331cfe7d8..47dbcd87c35b5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -25,24 +25,30 @@ import scala.collection.mutable.ListBuffer
import scala.concurrent.{Await, future, promise}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
+import scala.language.postfixOps
import scala.sys.process._
-import net.liftweb.json.JsonParser
+import org.json4s._
+import org.json4s.jackson.JsonMethods
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.deploy.master.RecoveryState
+import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil}
/**
* This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
* In order to mimic a real distributed cluster more closely, Docker is used.
* Execute using
- * ./spark-class org.apache.spark.deploy.FaultToleranceTest
+ * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest
*
- * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
+ * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS
+ * *and* SPARK_JAVA_OPTS:
* - spark.deploy.recoveryMode=ZOOKEEPER
* - spark.deploy.zookeeper.url=172.17.42.1:2181
* Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
*
+ * In case of failure, make sure to kill off prior docker containers before restarting:
+ * docker kill $(docker ps -q)
+ *
* Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
* working installation of Docker. In addition to having Docker, the following are assumed:
* - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
@@ -50,10 +56,16 @@ import org.apache.spark.deploy.master.RecoveryState
* docker/ directory. Run 'docker/spark-test/build' to generate these.
*/
private[spark] object FaultToleranceTest extends App with Logging {
+
+ val conf = new SparkConf()
+ val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark")
+
val masters = ListBuffer[TestMasterInfo]()
val workers = ListBuffer[TestWorkerInfo]()
var sc: SparkContext = _
+ val zk = SparkCuratorUtil.newClient(conf)
+
var numPassed = 0
var numFailed = 0
@@ -71,6 +83,10 @@ private[spark] object FaultToleranceTest extends App with Logging {
sc = null
}
terminateCluster()
+
+ // Clear ZK directories in between tests (for speed purposes)
+ SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/spark_leader")
+ SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/master_status")
}
test("sanity-basic") {
@@ -167,26 +183,34 @@ private[spark] object FaultToleranceTest extends App with Logging {
try {
fn
numPassed += 1
+ logInfo("==============================================")
logInfo("Passed: " + name)
+ logInfo("==============================================")
} catch {
case e: Exception =>
numFailed += 1
+ logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
logError("FAILED: " + name, e)
+ logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ sys.exit(1)
}
afterEach()
}
def addMasters(num: Int) {
+ logInfo(s">>>>> ADD MASTERS $num <<<<<")
(1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
}
def addWorkers(num: Int) {
+ logInfo(s">>>>> ADD WORKERS $num <<<<<")
val masterUrls = getMasterUrls(masters)
(1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
}
/** Creates a SparkContext, which constructs a Client to interact with our cluster. */
def createClient() = {
+ logInfo(">>>>> CREATE CLIENT <<<<<")
if (sc != null) { sc.stop() }
// Counter-hack: Because of a hack in SparkEnv#create() that changes this
// property, we need to reset it.
@@ -205,6 +229,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
}
def killLeader(): Unit = {
+ logInfo(">>>>> KILL LEADER <<<<<")
masters.foreach(_.readState())
val leader = getLeader
masters -= leader
@@ -214,6 +239,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
def terminateCluster() {
+ logInfo(">>>>> TERMINATE CLUSTER <<<<<")
masters.foreach(_.kill())
workers.foreach(_.kill())
masters.clear()
@@ -244,6 +270,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
* are all alive in a proper configuration (e.g., only one leader).
*/
def assertValidClusterState() = {
+ logInfo(">>>>> ASSERT VALID CLUSTER STATE <<<<<")
assertUsable()
var numAlive = 0
var numStandby = 0
@@ -311,7 +338,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File)
extends Logging {
- implicit val formats = net.liftweb.json.DefaultFormats
+ implicit val formats = org.json4s.DefaultFormats
var state: RecoveryState.Value = _
var liveWorkerIPs: List[String] = _
var numLiveApps = 0
@@ -321,11 +348,15 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val
def readState() {
try {
val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream)
- val json = JsonParser.parse(masterStream, closeAutomatically = true)
+ val json = JsonMethods.parse(masterStream)
val workers = json \ "workers"
val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE")
- liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String])
+ // Extract the worker IP from "webuiaddress" (rather than "host") because the host name
+ // on containers is a weird hash instead of the actual IP address.
+ liveWorkerIPs = liveWorkers.map {
+ w => (w \ "webuiaddress").extract[String].stripPrefix("http://").stripSuffix(":8081")
+ }
numLiveApps = (json \ "activeapps").children.size
@@ -349,7 +380,7 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val
private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File)
extends Logging {
- implicit val formats = net.liftweb.json.DefaultFormats
+ implicit val formats = org.json4s.DefaultFormats
logDebug("Created worker: " + this)
@@ -402,7 +433,7 @@ private[spark] object Docker extends Logging {
def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
- val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
+ val cmd = "docker run -privileged %s %s %s".format(mountCmd, imageTag, args)
logDebug("Run command: " + cmd)
cmd
}
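
readState above ports the master-state JSON handling from lift-json to json4s; the query operators carry over almost unchanged. A sketch of the same extraction against a canned payload, assuming json4s-jackson is on the classpath:

import org.json4s._
import org.json4s.jackson.JsonMethods

object MasterStateJson {
  implicit val formats = DefaultFormats

  def main(args: Array[String]) {
    val json = JsonMethods.parse(
      """{"workers": [
        |  {"state": "ALIVE", "webuiaddress": "http://10.0.0.1:8081"},
        |  {"state": "DEAD",  "webuiaddress": "http://10.0.0.2:8081"}
        |]}""".stripMargin)
    val liveWorkerIPs = (json \ "workers").children
      .filter(w => (w \ "state").extract[String] == "ALIVE")
      .map(w => (w \ "webuiaddress").extract[String].stripPrefix("http://").stripSuffix(":8081"))
    println(liveWorkerIPs)  // List(10.0.0.1)
  }
}
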
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index 318beb5db5214..c4f5e294a393e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy
-import net.liftweb.json.JsonDSL._
+import org.json4s.JsonDSL._
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
@@ -43,7 +43,6 @@ private[spark] object JsonProtocol {
("starttime" -> obj.startTime) ~
("id" -> obj.id) ~
("name" -> obj.desc.name) ~
- ("appuiurl" -> obj.appUiUrl) ~
("cores" -> obj.desc.maxCores) ~
("user" -> obj.desc.user) ~
("memoryperslave" -> obj.desc.memoryPerSlave) ~
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index a73b459c3cea1..9a7a113c95715 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -66,9 +66,9 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
// TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
// This is unfortunate, but for now we just comment it out.
workerActorSystems.foreach(_.shutdown())
- //workerActorSystems.foreach(_.awaitTermination())
+ // workerActorSystems.foreach(_.awaitTermination())
masterActorSystems.foreach(_.shutdown())
- //masterActorSystems.foreach(_.awaitTermination())
+ // masterActorSystems.foreach(_.awaitTermination())
masterActorSystems.clear()
workerActorSystems.clear()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index b479225b45ee9..498fcc520ac5e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,27 +21,24 @@ import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SparkContext, SparkException}
+import scala.collection.JavaConversions._
+
/**
* Contains util methods to interact with Hadoop from Spark.
*/
class SparkHadoopUtil {
- val conf = newConfiguration()
+ val conf: Configuration = newConfiguration()
UserGroupInformation.setConfiguration(conf)
def runAsUser(user: String)(func: () => Unit) {
- // if we are already running as the user intended there is no reason to do the doAs. It
- // will actually break secure HDFS access as it doesn't fill in the credentials. Also if
- // the user is UNKNOWN then we shouldn't be creating a remote unknown user
- // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
- // in SparkContext.
- val currentUser = Option(System.getProperty("user.name")).
- getOrElse(SparkContext.SPARK_UNKNOWN_USER)
- if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
+ if (user != SparkContext.SPARK_UNKNOWN_USER) {
val ugi = UserGroupInformation.createRemoteUser(user)
+ transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
ugi.doAs(new PrivilegedExceptionAction[Unit] {
def run: Unit = func()
})
@@ -50,6 +47,12 @@ class SparkHadoopUtil {
}
}
+ def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
+ for (token <- source.getTokens()) {
+ dest.addToken(token)
+ }
+ }
+
/**
* Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
* subsystems.
@@ -63,6 +66,19 @@ class SparkHadoopUtil {
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+
+ def getCurrentUserCredentials(): Credentials = { null }
+
+ def addCurrentUserCredentials(creds: Credentials) {}
+
+ def addSecretKeyToUserCredentials(key: String, secret: String) {}
+
+ def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
+
+ def loginUserFromKeytab(principalName: String, keytabFilename: String) {
+ UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
+ }
+
}
object SparkHadoopUtil {
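
The reworked runAsUser always wraps the call in a doAs for known users, and transferCredentials exists because a freshly created remote UGI carries no delegation tokens, which would break secure HDFS access. A condensed sketch of the pattern, assuming the Hadoop client libraries are on the classpath:

import java.security.PrivilegedExceptionAction

import scala.collection.JavaConversions._

import org.apache.hadoop.security.UserGroupInformation

object RunAsSketch {
  def runAsUser(user: String)(func: () => Unit) {
    val ugi = UserGroupInformation.createRemoteUser(user)
    // transferCredentials: copy the current user's tokens onto the new UGI
    for (token <- UserGroupInformation.getCurrentUser().getTokens()) {
      ugi.addToken(token)
    }
    ugi.doAs(new PrivilegedExceptionAction[Unit] {
      def run: Unit = func()
    })
  }
}
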
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
new file mode 100644
index 0000000000000..24edc60684376
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.{File, PrintStream}
+import java.net.{URI, URL}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
+
+import org.apache.spark.executor.ExecutorURLClassLoader
+
+/**
+ * Scala code behind the spark-submit script. The script handles setting up the classpath with
+ * relevant Spark dependencies and provides a layer over the different cluster managers and deploy
+ * modes that Spark supports.
+ */
+object SparkSubmit {
+ private val YARN = 1
+ private val STANDALONE = 2
+ private val MESOS = 4
+ private val LOCAL = 8
+ private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL
+
+ private var clusterManager: Int = LOCAL
+
+ /**
+ * A special jar name that indicates the class being run is inside of Spark itself,
+ * and therefore no user jar is needed.
+ */
+ private val RESERVED_JAR_NAME = "spark-internal"
+
+ def main(args: Array[String]) {
+ val appArgs = new SparkSubmitArguments(args)
+ if (appArgs.verbose) {
+ printStream.println(appArgs)
+ }
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
+ }
+
+ // Exposed for testing
+ private[spark] var printStream: PrintStream = System.err
+ private[spark] var exitFn: () => Unit = () => System.exit(-1)
+
+ private[spark] def printErrorAndExit(str: String) = {
+ printStream.println("error: " + str)
+ printStream.println("run with --help for more information or --verbose for debugging output")
+ exitFn()
+ }
+ private[spark] def printWarning(str: String) = printStream.println("warning: " + str)
+
+ /**
+ * @return
+ * a tuple containing the arguments for the child, a list of classpath
+   *         entries for the child, a list of system properties, a list of env vars
+ * and the main class for the child
+ */
+ private[spark] def createLaunchEnv(appArgs: SparkSubmitArguments): (ArrayBuffer[String],
+ ArrayBuffer[String], Map[String, String], String) = {
+ if (appArgs.master.startsWith("local")) {
+ clusterManager = LOCAL
+ } else if (appArgs.master.startsWith("yarn")) {
+ clusterManager = YARN
+ } else if (appArgs.master.startsWith("spark")) {
+ clusterManager = STANDALONE
+ } else if (appArgs.master.startsWith("mesos")) {
+ clusterManager = MESOS
+ } else {
+ printErrorAndExit("master must start with yarn, mesos, spark, or local")
+ }
+
+ // Because "yarn-cluster" and "yarn-client" encapsulate both the master
+ // and deploy mode, we have some logic to infer the master and deploy mode
+ // from each other if only one is specified, or exit early if they are at odds.
+ if (appArgs.deployMode == null &&
+ (appArgs.master == "yarn-standalone" || appArgs.master == "yarn-cluster")) {
+ appArgs.deployMode = "cluster"
+ }
+ if (appArgs.deployMode == "cluster" && appArgs.master == "yarn-client") {
+ printErrorAndExit("Deploy mode \"cluster\" and master \"yarn-client\" are not compatible")
+ }
+ if (appArgs.deployMode == "client" &&
+ (appArgs.master == "yarn-standalone" || appArgs.master == "yarn-cluster")) {
+ printErrorAndExit("Deploy mode \"client\" and master \"" + appArgs.master
+ + "\" are not compatible")
+ }
+ if (appArgs.deployMode == "cluster" && appArgs.master.startsWith("yarn")) {
+ appArgs.master = "yarn-cluster"
+ }
+ if (appArgs.deployMode != "cluster" && appArgs.master.startsWith("yarn")) {
+ appArgs.master = "yarn-client"
+ }
+
+ val deployOnCluster = Option(appArgs.deployMode).getOrElse("client") == "cluster"
+
+ val childClasspath = new ArrayBuffer[String]()
+ val childArgs = new ArrayBuffer[String]()
+ val sysProps = new HashMap[String, String]()
+ var childMainClass = ""
+
+ if (clusterManager == MESOS && deployOnCluster) {
+ printErrorAndExit("Mesos does not support running the driver on the cluster")
+ }
+
+ if (!deployOnCluster) {
+ childMainClass = appArgs.mainClass
+ if (appArgs.primaryResource != RESERVED_JAR_NAME) {
+ childClasspath += appArgs.primaryResource
+ }
+ } else if (clusterManager == YARN) {
+ childMainClass = "org.apache.spark.deploy.yarn.Client"
+ childArgs += ("--jar", appArgs.primaryResource)
+ childArgs += ("--class", appArgs.mainClass)
+ }
+
+ val options = List[OptionAssigner](
+ new OptionAssigner(appArgs.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"),
+ new OptionAssigner(appArgs.driverExtraClassPath, STANDALONE | YARN, true,
+ sysProp = "spark.driver.extraClassPath"),
+ new OptionAssigner(appArgs.driverExtraJavaOptions, STANDALONE | YARN, true,
+ sysProp = "spark.driver.extraJavaOptions"),
+ new OptionAssigner(appArgs.driverExtraLibraryPath, STANDALONE | YARN, true,
+ sysProp = "spark.driver.extraLibraryPath"),
+ new OptionAssigner(appArgs.driverMemory, YARN, true, clOption = "--driver-memory"),
+ new OptionAssigner(appArgs.name, YARN, true, clOption = "--name"),
+ new OptionAssigner(appArgs.queue, YARN, true, clOption = "--queue"),
+ new OptionAssigner(appArgs.queue, YARN, false, sysProp = "spark.yarn.queue"),
+ new OptionAssigner(appArgs.numExecutors, YARN, true, clOption = "--num-executors"),
+ new OptionAssigner(appArgs.numExecutors, YARN, false, sysProp = "spark.executor.instances"),
+ new OptionAssigner(appArgs.executorMemory, YARN, true, clOption = "--executor-memory"),
+ new OptionAssigner(appArgs.executorMemory, STANDALONE | MESOS | YARN, false,
+ sysProp = "spark.executor.memory"),
+ new OptionAssigner(appArgs.driverMemory, STANDALONE, true, clOption = "--memory"),
+ new OptionAssigner(appArgs.driverCores, STANDALONE, true, clOption = "--cores"),
+ new OptionAssigner(appArgs.executorCores, YARN, true, clOption = "--executor-cores"),
+ new OptionAssigner(appArgs.executorCores, YARN, false, sysProp = "spark.executor.cores"),
+ new OptionAssigner(appArgs.totalExecutorCores, STANDALONE | MESOS, false,
+ sysProp = "spark.cores.max"),
+ new OptionAssigner(appArgs.files, YARN, false, sysProp = "spark.yarn.dist.files"),
+ new OptionAssigner(appArgs.files, YARN, true, clOption = "--files"),
+ new OptionAssigner(appArgs.archives, YARN, false, sysProp = "spark.yarn.dist.archives"),
+ new OptionAssigner(appArgs.archives, YARN, true, clOption = "--archives"),
+ new OptionAssigner(appArgs.jars, YARN, true, clOption = "--addJars"),
+ new OptionAssigner(appArgs.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"),
+ new OptionAssigner(appArgs.jars, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.jars"),
+ new OptionAssigner(appArgs.name, LOCAL | STANDALONE | MESOS, false,
+ sysProp = "spark.app.name")
+ )
+
+ // For client mode, make any added jars immediately visible on the classpath
+ if (appArgs.jars != null && !deployOnCluster) {
+ for (jar <- appArgs.jars.split(",")) {
+ childClasspath += jar
+ }
+ }
+
+ for (opt <- options) {
+ if (opt.value != null && deployOnCluster == opt.deployOnCluster &&
+ (clusterManager & opt.clusterManager) != 0) {
+ if (opt.clOption != null) {
+ childArgs += (opt.clOption, opt.value)
+ } else if (opt.sysProp != null) {
+ sysProps.put(opt.sysProp, opt.value)
+ }
+ }
+ }
+
+ // For standalone mode, add the application jar automatically so the user doesn't have to
+ // call sc.addJar. TODO: support this in standalone cluster deploy mode as well
+ if (clusterManager == STANDALONE) {
+ val existingJars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq())
+ sysProps.put("spark.jars", (existingJars ++ Seq(appArgs.primaryResource)).mkString(","))
+ }
+
+ if (deployOnCluster && clusterManager == STANDALONE) {
+ if (appArgs.supervise) {
+ childArgs += "--supervise"
+ }
+
+ childMainClass = "org.apache.spark.deploy.Client"
+ childArgs += "launch"
+ childArgs += (appArgs.master, appArgs.primaryResource, appArgs.mainClass)
+ }
+
+ // Arguments to be passed to user program
+ if (appArgs.childArgs != null) {
+ if (!deployOnCluster || clusterManager == STANDALONE) {
+ childArgs ++= appArgs.childArgs
+ } else if (clusterManager == YARN) {
+ for (arg <- appArgs.childArgs) {
+ childArgs += ("--arg", arg)
+ }
+ }
+ }
+
+ for ((k, v) <- appArgs.getDefaultSparkProperties) {
+ if (!sysProps.contains(k)) sysProps(k) = v
+ }
+
+ (childArgs, childClasspath, sysProps, childMainClass)
+ }
+
+ private def launch(childArgs: ArrayBuffer[String], childClasspath: ArrayBuffer[String],
+ sysProps: Map[String, String], childMainClass: String, verbose: Boolean = false) {
+
+ if (verbose) {
+ printStream.println(s"Main class:\n$childMainClass")
+ printStream.println(s"Arguments:\n${childArgs.mkString("\n")}")
+ printStream.println(s"System properties:\n${sysProps.mkString("\n")}")
+ printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}")
+ printStream.println("\n")
+ }
+
+ val loader = new ExecutorURLClassLoader(new Array[URL](0),
+ Thread.currentThread.getContextClassLoader)
+ Thread.currentThread.setContextClassLoader(loader)
+
+ for (jar <- childClasspath) {
+ addJarToClasspath(jar, loader)
+ }
+
+ for ((key, value) <- sysProps) {
+ System.setProperty(key, value)
+ }
+
+ val mainClass = Class.forName(childMainClass, true, loader)
+ val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass)
+ mainMethod.invoke(null, childArgs.toArray)
+ }
+
+ private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) {
+ val localJarFile = new File(new URI(localJar).getPath())
+ if (!localJarFile.exists()) {
+ printWarning(s"Jar $localJar does not exist, skipping.")
+ }
+
+ val url = localJarFile.getAbsoluteFile.toURI.toURL
+ loader.addURL(url)
+ }
+}
+
+/**
+ * Provides an indirection layer for passing arguments as system properties or flags to
+ * the user's driver program or to downstream launcher tools.
+ */
+private[spark] class OptionAssigner(val value: String,
+ val clusterManager: Int,
+ val deployOnCluster: Boolean,
+ val clOption: String = null,
+ val sysProp: String = null
+) { }
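
The OptionAssigner list above is the core of createLaunchEnv: each option is declared exactly once, together with the cluster managers it applies to (a bitmask), the deploy mode it fires in, and whether it is forwarded to the child as a command-line flag or as a system property. A minimal self-contained sketch of that routing rule follows; the object and option values are illustrative, not part of the patch.

    // Sketch of the OptionAssigner routing rule, assuming the same bitmask encoding.
    object OptionRoutingSketch {
      val YARN = 1; val STANDALONE = 2; val MESOS = 4; val LOCAL = 8

      case class Opt(value: String, clusterManager: Int, deployOnCluster: Boolean,
          clOption: String = null, sysProp: String = null)

      def route(opts: Seq[Opt], clusterManager: Int, deployOnCluster: Boolean)
          : (Seq[String], Map[String, String]) = {
        val childArgs = scala.collection.mutable.ArrayBuffer[String]()
        val sysProps = scala.collection.mutable.HashMap[String, String]()
        for (opt <- opts) {
          // An option fires only if it has a value, matches the deploy mode, and its
          // bitmask includes the active cluster manager.
          if (opt.value != null && opt.deployOnCluster == deployOnCluster &&
              (clusterManager & opt.clusterManager) != 0) {
            if (opt.clOption != null) childArgs += (opt.clOption, opt.value)
            else if (opt.sysProp != null) sysProps(opt.sysProp) = opt.value
          }
        }
        (childArgs, sysProps.toMap)
      }

      def main(args: Array[String]) {
        val opts = Seq(
          Opt("2g", YARN, deployOnCluster = true, clOption = "--executor-memory"),
          Opt("2g", STANDALONE | MESOS | YARN, deployOnCluster = false,
            sysProp = "spark.executor.memory"))
        // Client mode on standalone: the value travels as a system property.
        println(route(opts, STANDALONE, deployOnCluster = false))
      }
    }

Because the managers are powers of two, a single entry can target several of them at once, e.g. STANDALONE | MESOS | YARN for spark.executor.memory.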
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
new file mode 100644
index 0000000000000..58d9e9add764a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.{File, FileInputStream, IOException}
+import java.util.Properties
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import org.apache.spark.SparkException
+import org.apache.spark.util.Utils
+
+/**
+ * Parses and encapsulates arguments from the spark-submit script.
+ */
+private[spark] class SparkSubmitArguments(args: Seq[String]) {
+ var master: String = null
+ var deployMode: String = null
+ var executorMemory: String = null
+ var executorCores: String = null
+ var totalExecutorCores: String = null
+ var propertiesFile: String = null
+ var driverMemory: String = null
+ var driverExtraClassPath: String = null
+ var driverExtraLibraryPath: String = null
+ var driverExtraJavaOptions: String = null
+ var driverCores: String = null
+ var supervise: Boolean = false
+ var queue: String = null
+ var numExecutors: String = null
+ var files: String = null
+ var archives: String = null
+ var mainClass: String = null
+ var primaryResource: String = null
+ var name: String = null
+ var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
+ var jars: String = null
+ var verbose: Boolean = false
+
+ parseOpts(args.toList)
+ loadDefaults()
+ checkRequiredArguments()
+
+ /** Return the default Spark properties present in the currently defined defaults file. */
+ def getDefaultSparkProperties = {
+ val defaultProperties = new HashMap[String, String]()
+ if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
+ Option(propertiesFile).foreach { filename =>
+ val file = new File(filename)
+ SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) =>
+ if (k.startsWith("spark")) {
+ defaultProperties(k) = v
+ if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
+ } else {
+ SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v")
+ }
+ }
+ }
+ defaultProperties
+ }
+
+ /** Fill in any undefined values based on the current properties file or built-in defaults. */
+ private def loadDefaults() = {
+
+ // Use common defaults file, if not specified by user
+ if (propertiesFile == null) {
+ sys.env.get("SPARK_HOME").foreach { sparkHome =>
+ val sep = File.separator
+ val defaultPath = s"${sparkHome}${sep}conf${sep}spark-defaults.conf"
+ val file = new File(defaultPath)
+ if (file.exists()) {
+ propertiesFile = file.getAbsolutePath
+ }
+ }
+ }
+
+ val defaultProperties = getDefaultSparkProperties
+ // Use properties file as fallback for values which have a direct analog to
+ // arguments in this script.
+ master = Option(master).getOrElse(defaultProperties.get("spark.master").orNull)
+ executorMemory = Option(executorMemory)
+ .getOrElse(defaultProperties.get("spark.executor.memory").orNull)
+ executorCores = Option(executorCores)
+ .getOrElse(defaultProperties.get("spark.executor.cores").orNull)
+ totalExecutorCores = Option(totalExecutorCores)
+ .getOrElse(defaultProperties.get("spark.cores.max").orNull)
+ name = Option(name).getOrElse(defaultProperties.get("spark.app.name").orNull)
+ jars = Option(jars).getOrElse(defaultProperties.get("spark.jars").orNull)
+
+ // This supports env vars in older versions of Spark
+ master = Option(master).getOrElse(System.getenv("MASTER"))
+ deployMode = Option(deployMode).getOrElse(System.getenv("DEPLOY_MODE"))
+
+ // Global defaults. These should be kept to a minimum to avoid confusing behavior.
+ master = Option(master).getOrElse("local[*]")
+ }
+
+ /** Ensure that required fields exist. Call this only once all defaults are loaded. */
+ private def checkRequiredArguments() = {
+ if (args.length == 0) printUsageAndExit(-1)
+ if (primaryResource == null) SparkSubmit.printErrorAndExit("Must specify a primary resource")
+ if (mainClass == null) SparkSubmit.printErrorAndExit("Must specify a main class with --class")
+
+ if (master.startsWith("yarn")) {
+ val hasHadoopEnv = sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR")
+ if (!hasHadoopEnv && !Utils.isTesting) {
+ throw new Exception(s"When running with master '$master' " +
+ "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.")
+ }
+ }
+ }
+
+ override def toString = {
+ s"""Parsed arguments:
+ | master $master
+ | deployMode $deployMode
+ | executorMemory $executorMemory
+ | executorCores $executorCores
+ | totalExecutorCores $totalExecutorCores
+ | propertiesFile $propertiesFile
+ | driverMemory $driverMemory
+ | driverCores $driverCores
+ | driverExtraClassPath $driverExtraClassPath
+ | driverExtraLibraryPath $driverExtraLibraryPath
+ | driverExtraJavaOptions $driverExtraJavaOptions
+ | supervise $supervise
+ | queue $queue
+ | numExecutors $numExecutors
+ | files $files
+ | archives $archives
+ | mainClass $mainClass
+ | primaryResource $primaryResource
+ | name $name
+ | childArgs [${childArgs.mkString(" ")}]
+ | jars $jars
+ | verbose $verbose
+ |
+ |Default properties from $propertiesFile:
+ |${getDefaultSparkProperties.mkString(" ", "\n ", "\n")}
+ """.stripMargin
+ }
+
+ /** Fill in values by parsing user options. */
+ private def parseOpts(opts: Seq[String]): Unit = {
+ // Delineates parsing of Spark options from parsing of user options.
+ var inSparkOpts = true
+ parse(opts)
+
+ def parse(opts: Seq[String]): Unit = opts match {
+ case ("--name") :: value :: tail =>
+ name = value
+ parse(tail)
+
+ case ("--master") :: value :: tail =>
+ master = value
+ parse(tail)
+
+ case ("--class") :: value :: tail =>
+ mainClass = value
+ parse(tail)
+
+ case ("--deploy-mode") :: value :: tail =>
+ if (value != "client" && value != "cluster") {
+ SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"")
+ }
+ deployMode = value
+ parse(tail)
+
+ case ("--num-executors") :: value :: tail =>
+ numExecutors = value
+ parse(tail)
+
+ case ("--total-executor-cores") :: value :: tail =>
+ totalExecutorCores = value
+ parse(tail)
+
+ case ("--executor-cores") :: value :: tail =>
+ executorCores = value
+ parse(tail)
+
+ case ("--executor-memory") :: value :: tail =>
+ executorMemory = value
+ parse(tail)
+
+ case ("--driver-memory") :: value :: tail =>
+ driverMemory = value
+ parse(tail)
+
+ case ("--driver-cores") :: value :: tail =>
+ driverCores = value
+ parse(tail)
+
+ case ("--driver-class-path") :: value :: tail =>
+ driverExtraClassPath = value
+ parse(tail)
+
+ case ("--driver-java-options") :: value :: tail =>
+ driverExtraJavaOptions = value
+ parse(tail)
+
+ case ("--driver-library-path") :: value :: tail =>
+ driverExtraLibraryPath = value
+ parse(tail)
+
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
+ case ("--supervise") :: tail =>
+ supervise = true
+ parse(tail)
+
+ case ("--queue") :: value :: tail =>
+ queue = value
+ parse(tail)
+
+ case ("--files") :: value :: tail =>
+ files = value
+ parse(tail)
+
+ case ("--archives") :: value :: tail =>
+ archives = value
+ parse(tail)
+
+ case ("--jars") :: value :: tail =>
+ jars = value
+ parse(tail)
+
+ case ("--help" | "-h") :: tail =>
+ printUsageAndExit(0)
+
+ case ("--verbose" | "-v") :: tail =>
+ verbose = true
+ parse(tail)
+
+ case value :: tail =>
+ if (inSparkOpts) {
+ value match {
+ // convert --foo=bar to --foo bar
+ case v if v.startsWith("--") && v.contains("=") && v.split("=").size == 2 =>
+ val parts = v.split("=")
+ parse(Seq(parts(0), parts(1)) ++ tail)
+ case v if v.startsWith("-") =>
+ val errMessage = s"Unrecognized option '$value'."
+ SparkSubmit.printErrorAndExit(errMessage)
+ case v =>
+ primaryResource = v
+ inSparkOpts = false
+ parse(tail)
+ }
+ } else {
+ childArgs += value
+ parse(tail)
+ }
+
+ case Nil =>
+ }
+ }
+
+ private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ val outStream = SparkSubmit.printStream
+ if (unknownParam != null) {
+ outStream.println("Unknown/unsupported param " + unknownParam)
+ }
+ outStream.println(
+ """Usage: spark-submit [options] [app options]
+ |Options:
+ | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
+ | --deploy-mode DEPLOY_MODE Mode to deploy the app in, either 'client' or 'cluster'.
+ | --class CLASS_NAME Name of your app's main class (required for Java apps).
+ | --arg ARG Argument to be passed to your application's main class. This
+ | option can be specified multiple times for multiple args.
+ | --name NAME The name of your application (Default: 'Spark').
+ | --jars JARS A comma-separated list of local jars to include on the
+ | driver classpath and that SparkContext.addJar will work
+ | with. Doesn't work on standalone with 'cluster' deploy mode.
+ | --files FILES Comma-separated list of files to be placed in the working dir
+ | of each executor.
+ | --properties-file FILE Path to a file from which to load extra properties. If not
+ | specified, this will look for conf/spark-defaults.conf.
+ |
+ | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M).
+ | --driver-java-options Extra Java options to pass to the driver
+ | --driver-library-path Extra library path entries to pass to the driver
+ | --driver-class-path Extra class path entries to pass to the driver
+ |
+ | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G).
+ |
+ | Spark standalone with cluster deploy mode only:
+ | --driver-cores NUM Cores for driver (Default: 1).
+ | --supervise If given, restarts the driver on failure.
+ |
+ | Spark standalone and Mesos only:
+ | --total-executor-cores NUM Total cores for all executors.
+ |
+ | YARN-only:
+ | --executor-cores NUM Number of cores per executor (Default: 1).
+ | --queue QUEUE_NAME The YARN queue to submit to (Default: 'default').
+ | --num-executors NUM Number of executors to launch (Default: 2).
+ | --archives ARCHIVES Comma-separated list of archives to be extracted into the
+ | working dir of each executor.""".stripMargin
+ )
+ SparkSubmit.exitFn()
+ }
+}
+
+object SparkSubmitArguments {
+ /** Load properties present in the given file. */
+ def getPropertiesFromFile(file: File): Seq[(String, String)] = {
+ require(file.exists(), s"Properties file ${file.getName} does not exist")
+ val inputStream = new FileInputStream(file)
+ val properties = new Properties()
+ try {
+ properties.load(inputStream)
+ } catch {
+ case e: IOException =>
+ val message = s"Failed when loading Spark properties file ${file.getName}"
+ throw new SparkException(message, e)
+ } finally {
+ inputStream.close()
+ }
+ properties.stringPropertyNames().toSeq.map(k => (k, properties(k)))
+ }
+}
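
loadDefaults above resolves each setting through a fixed precedence chain: an explicit command-line flag wins, then the properties file, then a legacy environment variable, then a built-in default. A hedged sketch of the same Option fallback idiom, with illustrative names:

    object PrecedenceSketch {
      // Mirrors the loadDefaults chain for spark.master: flag, then properties
      // file, then legacy MASTER env var, then the hard-coded default.
      def resolveMaster(flag: String, fromFile: Option[String], env: Option[String]): String =
        Option(flag)
          .orElse(fromFile)
          .orElse(env)
          .getOrElse("local[*]")

      def main(args: Array[String]) {
        assert(resolveMaster(null, None, Some("spark://host:7077")) == "spark://host:7077")
        assert(resolveMaster(null, None, None) == "local[*]")
      }
    }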
diff --git a/core/src/main/scala/org/apache/spark/deploy/WebUI.scala b/core/src/main/scala/org/apache/spark/deploy/WebUI.scala
deleted file mode 100644
index ae258b58b9cc5..0000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/WebUI.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy
-
-import java.text.SimpleDateFormat
-import java.util.Date
-
-/**
- * Utilities used throughout the web UI.
- */
-private[spark] object DeployWebUI {
- val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
-
- def formatDate(date: Date): String = DATE_FORMAT.format(date)
-
- def formatDate(timestamp: Long): String = DATE_FORMAT.format(new Date(timestamp))
-
- def formatDuration(milliseconds: Long): String = {
- val seconds = milliseconds.toDouble / 1000
- if (seconds < 60) {
- return "%.0f s".format(seconds)
- }
- val minutes = seconds / 60
- if (minutes < 10) {
- return "%.1f min".format(minutes)
- } else if (minutes < 60) {
- return "%.0f min".format(minutes)
- }
- val hours = minutes / 60
- return "%.1f h".format(hours)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 1550c3eb4286b..888dd45e93c6a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.client
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -45,11 +45,12 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
+ val conf = new SparkConf
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
- conf = new SparkConf)
+ conf = conf, securityManager = new SecurityManager(conf))
val desc = new ApplicationDescription(
- "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
- Some("dummy-spark-home"), "ignored")
+ "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(),
+ Seq()), Some("dummy-spark-home"), "ignored")
val listener = new TestListener
val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf)
client.start()
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
new file mode 100644
index 0000000000000..180c853ce3096
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.history
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+
+private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val appRows = parent.appIdToInfo.values.toSeq.sortBy { app => -app.lastUpdated }
+ val appTable = UIUtils.listingTable(appHeader, appRow, appRows)
+    val content =
+      <div class="row-fluid">
+        <div class="span12">
+          <h4>Completed Applications</h4>
+          {appTable}
+        </div>
+      </div>
+    UIUtils.basicSparkPage(content, "History Server")
+  }
+
+  private val appHeader = Seq("App Name", "Started", "Completed", "Spark User", "Last Updated")
+
+  private def appRow(info: ApplicationHistoryInfo): Seq[Node] = {
+    val uiAddress = parent.getAddress + info.ui.basePath
+    <tr>
+      <td><a href={uiAddress}>{info.name}</a></td>
+      <td>{UIUtils.formatDate(info.startTime)}</td>
+      <td>{UIUtils.formatDate(info.endTime)}</td>
+      <td>{info.sparkUser}</td>
+      <td>{UIUtils.formatDate(info.lastUpdated)}</td>
+    </tr>
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
new file mode 100644
index 0000000000000..1238bbf9da2fd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.history
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
+import org.apache.spark.ui.{WebUI, SparkUI}
+import org.apache.spark.ui.JettyUtils._
+import org.apache.spark.util.Utils
+
+/**
+ * A web server that renders SparkUIs of completed applications.
+ *
+ * In standalone mode, the MasterWebUI already provides this functionality. Thus, the
+ * main use case of the HistoryServer is in other deploy modes (e.g. Yarn or Mesos).
+ *
+ * The logging directory structure is as follows: Within the given base directory, each
+ * application's event logs are maintained in the application's own sub-directory. This
+ * is the same structure as maintained in the event log write code path in
+ * EventLoggingListener.
+ *
+ * @param baseLogDir The base directory in which event logs are found
+ */
+class HistoryServer(
+ val baseLogDir: String,
+ securityManager: SecurityManager,
+ conf: SparkConf)
+ extends WebUI(securityManager, HistoryServer.WEB_UI_PORT, conf) with Logging {
+
+ import HistoryServer._
+
+ private val fileSystem = Utils.getHadoopFileSystem(baseLogDir)
+ private val localHost = Utils.localHostName()
+ private val publicHost = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHost)
+
+ // A timestamp of when the disk was last accessed to check for log updates
+ private var lastLogCheckTime = -1L
+
+ // Number of completed applications found in this directory
+ private var numCompletedApplications = 0
+
+ @volatile private var stopped = false
+
+ /**
+ * A background thread that periodically checks for event log updates on disk.
+ *
+ * If a log check is invoked manually in the middle of a period, this thread re-adjusts the
+ * time at which it performs the next log check to maintain the same period as before.
+ *
+ * TODO: Add a mechanism to update manually.
+ */
+ private val logCheckingThread = new Thread {
+ override def run() {
+ while (!stopped) {
+ val now = System.currentTimeMillis
+ if (now - lastLogCheckTime > UPDATE_INTERVAL_MS) {
+ checkForLogs()
+ Thread.sleep(UPDATE_INTERVAL_MS)
+ } else {
+ // If the user has manually checked for logs recently, wait until
+ // UPDATE_INTERVAL_MS after the last check time
+ Thread.sleep(lastLogCheckTime + UPDATE_INTERVAL_MS - now)
+ }
+ }
+ }
+ }
+
+ // A mapping of application ID to its history information, which includes the rendered UI
+ val appIdToInfo = mutable.HashMap[String, ApplicationHistoryInfo]()
+
+ initialize()
+
+ /**
+ * Initialize the history server.
+ *
+ * This starts a background thread that periodically synchronizes information displayed on
+ * this UI with the event logs in the provided base directory.
+ */
+ def initialize() {
+ attachPage(new HistoryPage(this))
+ attachHandler(createStaticHandler(STATIC_RESOURCE_DIR, "/static"))
+ }
+
+ /** Bind the HTTP server behind this web interface and start the log checking thread. */
+ override def bind() {
+ super.bind()
+ logCheckingThread.start()
+ }
+
+ /**
+ * Check for any updates to event logs in the base directory. This is only effective once
+ * the server has been bound.
+ *
+ * If a new completed application is found, the server renders the associated SparkUI
+ * from the application's event logs, attaches this UI to itself, and stores metadata
+ * information for this application.
+ *
+ * If the logs for an existing completed application are no longer found, the server
+ * removes all associated information and detaches the SparkUI.
+ */
+ def checkForLogs() = synchronized {
+ if (serverInfo.isDefined) {
+ lastLogCheckTime = System.currentTimeMillis
+ logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTime))
+ try {
+ val logStatus = fileSystem.listStatus(new Path(baseLogDir))
+ val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]()
+ val logInfos = logDirs
+ .sortBy { dir => getModificationTime(dir) }
+ .map { dir => (dir, EventLoggingListener.parseLoggingInfo(dir.getPath, fileSystem)) }
+ .filter { case (dir, info) => info.applicationComplete }
+
+ // Logging information for applications that should be retained
+ val retainedLogInfos = logInfos.takeRight(RETAINED_APPLICATIONS)
+ val retainedAppIds = retainedLogInfos.map { case (dir, _) => dir.getPath.getName }
+
+ // Remove any applications that should no longer be retained
+ appIdToInfo.foreach { case (appId, info) =>
+ if (!retainedAppIds.contains(appId)) {
+ detachSparkUI(info.ui)
+ appIdToInfo.remove(appId)
+ }
+ }
+
+ // Render the application's UI if it is not already there
+ retainedLogInfos.foreach { case (dir, info) =>
+ val appId = dir.getPath.getName
+ if (!appIdToInfo.contains(appId)) {
+ renderSparkUI(dir, info)
+ }
+ }
+
+ // Track the total number of completed applications observed this round
+ numCompletedApplications = logInfos.size
+
+ } catch {
+ case t: Throwable => logError("Exception in checking for event log updates", t)
+ }
+ } else {
+ logWarning("Attempted to check for event log updates before binding the server.")
+ }
+ }
+
+ /**
+ * Render a new SparkUI from the event logs if the associated application is completed.
+ *
+ * HistoryServer looks for a special file that indicates application completion in the given
+ * directory. If this file exists, the associated application is regarded as complete, in
+ * which case the server proceeds to render the SparkUI. Otherwise, the server does nothing.
+ */
+ private def renderSparkUI(logDir: FileStatus, elogInfo: EventLoggingInfo) {
+ val path = logDir.getPath
+ val appId = path.getName
+ val replayBus = new ReplayListenerBus(elogInfo.logPaths, fileSystem, elogInfo.compressionCodec)
+ val appListener = new ApplicationEventListener
+ replayBus.addListener(appListener)
+ val appConf = conf.clone()
+ val appSecManager = new SecurityManager(appConf)
+ val ui = new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId)
+
+ // Do not call ui.bind() to avoid creating a new server for each application
+ replayBus.replay()
+ if (appListener.applicationStarted) {
+ appSecManager.setUIAcls(HISTORY_UI_ACLS_ENABLED)
+ appSecManager.setViewAcls(appListener.sparkUser, appListener.viewAcls)
+ attachSparkUI(ui)
+ val appName = appListener.appName
+ val sparkUser = appListener.sparkUser
+ val startTime = appListener.startTime
+ val endTime = appListener.endTime
+ val lastUpdated = getModificationTime(logDir)
+ ui.setAppName(appName + " (completed)")
+ appIdToInfo(appId) = ApplicationHistoryInfo(appId, appName, startTime, endTime,
+ lastUpdated, sparkUser, path, ui)
+ }
+ }
+
+ /** Stop the server and close the file system. */
+ override def stop() {
+ super.stop()
+ stopped = true
+ fileSystem.close()
+ }
+
+ /** Attach a reconstructed UI to this server. Only valid after bind(). */
+ private def attachSparkUI(ui: SparkUI) {
+ assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs")
+ ui.getHandlers.foreach(attachHandler)
+ addFilters(ui.getHandlers, conf)
+ }
+
+ /** Detach a reconstructed UI from this server. Only valid after bind(). */
+ private def detachSparkUI(ui: SparkUI) {
+ assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs")
+ ui.getHandlers.foreach(detachHandler)
+ }
+
+ /** Return the address of this server. */
+ def getAddress: String = "http://" + publicHost + ":" + boundPort
+
+ /** Return the number of completed applications found, whether or not the UI is rendered. */
+ def getNumApplications: Int = numCompletedApplications
+
+ /** Return when this directory was last modified. */
+ private def getModificationTime(dir: FileStatus): Long = {
+ try {
+ val logFiles = fileSystem.listStatus(dir.getPath)
+ if (logFiles != null && !logFiles.isEmpty) {
+ logFiles.map(_.getModificationTime).max
+ } else {
+ dir.getModificationTime
+ }
+ } catch {
+ case t: Throwable =>
+ logError("Exception in accessing modification time of %s".format(dir.getPath), t)
+ -1L
+ }
+ }
+}
+
+/**
+ * The recommended way of starting and stopping a HistoryServer is through the scripts
+ * start-history-server.sh and stop-history-server.sh. The path to a base log directory
+ * must be specified, while the requested UI port is optional. For example:
+ *
+ * ./sbin/start-history-server.sh /tmp/spark-events
+ * ./sbin/start-history-server.sh hdfs://1.2.3.4:9000/spark-events
+ *
+ * This launches the HistoryServer as a Spark daemon.
+ */
+object HistoryServer {
+ private val conf = new SparkConf
+
+ // Interval between each check for event log updates
+ val UPDATE_INTERVAL_MS = conf.getInt("spark.history.updateInterval", 10) * 1000
+
+ // How many applications to retain
+ val RETAINED_APPLICATIONS = conf.getInt("spark.history.retainedApplications", 250)
+
+ // The port to which the web UI is bound
+ val WEB_UI_PORT = conf.getInt("spark.history.ui.port", 18080)
+
+ // Whether to enable view ACLs for all applications
+ val HISTORY_UI_ACLS_ENABLED = conf.getBoolean("spark.history.ui.acls.enable", false)
+
+ val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR
+
+ def main(argStrings: Array[String]) {
+ initSecurity()
+ val args = new HistoryServerArguments(argStrings)
+ val securityManager = new SecurityManager(conf)
+ val server = new HistoryServer(args.logDir, securityManager, conf)
+ server.bind()
+
+ // Wait until the end of the world... or until the HistoryServer process is manually stopped
+ while (true) { Thread.sleep(Int.MaxValue) }
+ server.stop()
+ }
+
+ def initSecurity() {
+ // If we are accessing HDFS and it has security enabled (Kerberos), we have to log in
+ // from a keytab file so that we can access HDFS beyond the Kerberos ticket expiration.
+ // As long as it is using Hadoop rpc (hdfs://), a relogin will automatically
+ // occur from the keytab.
+ if (conf.getBoolean("spark.history.kerberos.enabled", false)) {
+ // If Kerberos is enabled, the following two params must be set
+ val principalName = conf.get("spark.history.kerberos.principal")
+ val keytabFilename = conf.get("spark.history.kerberos.keytab")
+ SparkHadoopUtil.get.loginUserFromKeytab(principalName, keytabFilename)
+ }
+ }
+
+}
+
+
+private[spark] case class ApplicationHistoryInfo(
+ id: String,
+ name: String,
+ startTime: Long,
+ endTime: Long,
+ lastUpdated: Long,
+ sparkUser: String,
+ logDirPath: Path,
+ ui: SparkUI) {
+ def started = startTime != -1
+ def completed = endTime != -1
+}
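
checkForLogs in the class above reduces to a retention policy: sort the completed-application log directories by modification time, keep the newest RETAINED_APPLICATIONS, detach the UIs of everything else, and render UIs only for newly retained entries. A simplified sketch of that bookkeeping, with the file-system and UI types reduced to plain values and illustrative names:

    import scala.collection.mutable

    object RetentionSketch {
      case class AppLog(id: String, lastUpdated: Long)

      /** Keep only the newest `max` applications; return the ids that were evicted. */
      def retain(current: mutable.HashMap[String, AppLog],
          found: Seq[AppLog], max: Int): Seq[String] = {
        val retained = found.sortBy(_.lastUpdated).takeRight(max).map(_.id).toSet
        val evicted = current.keys.filterNot(retained.contains).toSeq
        evicted.foreach(current.remove)  // mirrors detachSparkUI + appIdToInfo.remove
        found.filter(a => retained.contains(a.id))
          .foreach(a => current.getOrElseUpdate(a.id, a))  // mirrors renderSparkUI
        evicted
      }

      def main(args: Array[String]) {
        val current = mutable.HashMap("a" -> AppLog("a", 1L))
        val found = Seq(AppLog("a", 1L), AppLog("b", 2L), AppLog("c", 3L))
        println(retain(current, found, max = 2))  // evicts "a"
      }
    }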
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
new file mode 100644
index 0000000000000..943c061743dbd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.history
+
+import java.net.URI
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.util.Utils
+
+/**
+ * Command-line parser for the HistoryServer.
+ */
+private[spark] class HistoryServerArguments(args: Array[String]) {
+ var logDir = ""
+
+ parse(args.toList)
+
+ private def parse(args: List[String]): Unit = {
+ args match {
+ case ("--dir" | "-d") :: value :: tail =>
+ logDir = value
+ parse(tail)
+
+ case ("--help" | "-h") :: tail =>
+ printUsageAndExit(0)
+
+ case Nil =>
+
+ case _ =>
+ printUsageAndExit(1)
+ }
+ validateLogDir()
+ }
+
+ private def validateLogDir() {
+ if (logDir == "") {
+ System.err.println("Logging directory must be specified.")
+ printUsageAndExit(1)
+ }
+ val fileSystem = Utils.getHadoopFileSystem(new URI(logDir))
+ val path = new Path(logDir)
+ if (!fileSystem.exists(path)) {
+ System.err.println("Logging directory specified does not exist: %s".format(logDir))
+ printUsageAndExit(1)
+ }
+ if (!fileSystem.getFileStatus(path).isDir) {
+ System.err.println("Logging directory specified is not a directory: %s".format(logDir))
+ printUsageAndExit(1)
+ }
+ }
+
+ private def printUsageAndExit(exitCode: Int) {
+ System.err.println(
+ "Usage: HistoryServer [options]\n" +
+ "\n" +
+ "Options:\n" +
+ " -d DIR, --dir DIR Location of event log files")
+ System.exit(exitCode)
+ }
+}
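
The parser above uses the same tail-recursive match over List[String] as SparkSubmitArguments: consume a flag and its value, recurse on the tail, and stop at Nil. A self-contained sketch of the idiom with illustrative names:

    object ArgParseSketch {
      /** Tail-recursive option parsing over a List[String], as in the patch. */
      def parse(args: List[String], acc: Map[String, String] = Map()): Map[String, String] =
        args match {
          case ("--dir" | "-d") :: value :: tail => parse(tail, acc + ("dir" -> value))
          case Nil => acc
          case unknown :: _ => sys.error(s"Unrecognized option '$unknown'")
        }

      def main(args: Array[String]) {
        println(parse(List("--dir", "/tmp/spark-events")))  // Map(dir -> /tmp/spark-events)
      }
    }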
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index e8867bc1691d3..46b9f4dc7d3ba 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -31,7 +31,6 @@ private[spark] class ApplicationInfo(
val desc: ApplicationDescription,
val submitDate: Date,
val driver: ActorRef,
- val appUiUrl: String,
defaultCores: Int)
extends Serializable {
@@ -45,11 +44,6 @@ private[spark] class ApplicationInfo(
init()
- private def readObject(in: java.io.ObjectInputStream) : Unit = {
- in.defaultReadObject()
- init()
- }
-
private def init() {
state = ApplicationState.WAITING
executors = new mutable.HashMap[Int, ExecutorInfo]
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index f25a1ad3bf92a..4433a2ec29be6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -30,6 +30,7 @@ import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
* [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
*/
private[spark] trait LeaderElectionAgent extends Actor {
+ // TODO: LeaderElectionAgent no longer needs to be an Actor; this needs refactoring.
val masterActor: ActorRef
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 51794ce40cb45..fdb633bd33608 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -23,28 +23,38 @@ import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.concurrent.Await
import scala.concurrent.duration._
+import scala.language.postfixOps
import scala.util.Random
import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
+import org.apache.hadoop.fs.FileSystem
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
+import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{AkkaUtils, Utils}
-private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(
+ host: String,
+ port: Int,
+ webUiPort: Int,
+ val securityMgr: SecurityManager)
+ extends Actor with Logging {
+
import context.dispatcher // to use Akka's scheduler.schedule()
val conf = new SparkConf
- val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000
val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
@@ -63,6 +73,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val completedApps = new ArrayBuffer[ApplicationInfo]
var nextAppNumber = 0
+ val appIdToUI = new HashMap[String, SparkUI]
+ val fileSystemsUsed = new HashSet[FileSystem]
+
val drivers = new HashSet[DriverInfo]
val completedDrivers = new ArrayBuffer[DriverInfo]
val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling
@@ -70,8 +83,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
Utils.checkHost(host, "Expected hostname")
- val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf)
- val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf)
+ val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
+ val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
+ securityMgr)
val masterSource = new MasterSource(this)
val webUi = new MasterWebUI(this, webUiPort)
@@ -105,8 +119,8 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- webUi.start()
- masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
+ webUi.bind()
+ masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
masterMetricsSystem.registerSource(masterSource)
@@ -139,6 +153,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
override def postStop() {
webUi.stop()
+ fileSystemsUsed.foreach(_.close())
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
@@ -222,8 +237,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (waitingDrivers.contains(d)) {
waitingDrivers -= d
self ! DriverStateChanged(driverId, DriverState.KILLED, None)
- }
- else {
+ } else {
// We just notify the worker to kill the driver here. The final bookkeeping occurs
// on the return path when the worker submits a state change back to the master
// to notify it that the driver was successfully killed.
@@ -371,7 +385,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
case RequestWebUIPort => {
- sender ! WebUIPortResponse(webUi.boundPort.getOrElse(-1))
+ sender ! WebUIPortResponse(webUi.boundPort)
}
}
@@ -529,8 +543,15 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val workerAddress = worker.actor.path.address
if (addressToWorker.contains(workerAddress)) {
- logInfo("Attempted to re-register worker at same address: " + workerAddress)
- return false
+ val oldWorker = addressToWorker(workerAddress)
+ if (oldWorker.state == WorkerState.UNKNOWN) {
+ // A worker registering from UNKNOWN implies that the worker was restarted during recovery.
+ // The old worker must thus be dead, so we will remove it and accept the new worker.
+ removeWorker(oldWorker)
+ } else {
+ logInfo("Attempted to re-register worker at same address: " + workerAddress)
+ return false
+ }
}
workers += worker
@@ -572,8 +593,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- new ApplicationInfo(
- now, newApplicationId(date), desc, date, driver, desc.appUiUrl, defaultCores)
+ new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores)
}
def registerApplication(app: ApplicationInfo): Unit = {
@@ -605,12 +625,20 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (completedApps.size >= RETAINED_APPLICATIONS) {
val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1)
completedApps.take(toRemove).foreach( a => {
+ appIdToUI.remove(a.id).foreach { ui => webUi.detachSparkUI(ui) }
applicationMetricsSystem.removeSource(a.appSource)
})
completedApps.trimStart(toRemove)
}
completedApps += app // Remember it in our history
waitingApps -= app
+
+ // If application events are logged, use them to rebuild the UI
+ if (!rebuildSparkUI(app)) {
+ // Avoid broken links if the UI is not reconstructed
+ app.desc.appUiUrl = ""
+ }
+
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
@@ -625,9 +653,40 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ /**
+ * Rebuild a new SparkUI from the given application's event logs.
+ * Return whether this is successful.
+ */
+ def rebuildSparkUI(app: ApplicationInfo): Boolean = {
+ val appName = app.desc.name
+ val eventLogDir = app.desc.eventLogDir.getOrElse { return false }
+ val fileSystem = Utils.getHadoopFileSystem(eventLogDir)
+ val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem)
+ val eventLogPaths = eventLogInfo.logPaths
+ val compressionCodec = eventLogInfo.compressionCodec
+ if (!eventLogPaths.isEmpty) {
+ try {
+ val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec)
+ val ui = new SparkUI(
+ new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id)
+ replayBus.replay()
+ app.desc.appUiUrl = ui.basePath
+ appIdToUI(app.id) = ui
+ webUi.attachSparkUI(ui)
+ return true
+ } catch {
+ case t: Throwable =>
+ logError("Exception in replaying log for application %s (%s)".format(appName, app.id), t)
+ }
+ } else {
+ logWarning("Application %s (%s) has no valid logs: %s".format(appName, app.id, eventLogDir))
+ }
+ false
+ }
+
 /** Generate a new app ID given an app's submission date */
def newApplicationId(submitDate: Date): String = {
- val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber)
+ val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber)
nextAppNumber += 1
appId
}
@@ -651,7 +710,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
def newDriverId(submitDate: Date): String = {
- val appId = "driver-%s-%04d".format(DATE_FORMAT.format(submitDate), nextDriverNumber)
+ val appId = "driver-%s-%04d".format(createDateFormat.format(submitDate), nextDriverNumber)
nextDriverNumber += 1
appId
}
@@ -708,11 +767,16 @@ private[spark] object Master {
}
}
- def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf)
- : (ActorSystem, Int, Int) =
- {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf)
- val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
+ def startSystemAndActor(
+ host: String,
+ port: Int,
+ webUiPort: Int,
+ conf: SparkConf): (ActorSystem, Int, Int) = {
+ val securityMgr = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
+ securityManager = securityMgr)
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
+ securityMgr), actorName)
val timeout = AkkaUtils.askTimeout(conf)
val respFuture = actor.ask(RequestWebUIPort)(timeout)
val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
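
One subtle fix in this file is the change from a shared val DATE_FORMAT to def createDateFormat: java.text.SimpleDateFormat is not thread-safe, so handing each caller a fresh instance avoids shared mutable state. A minimal illustration of the per-call pattern:

    import java.text.SimpleDateFormat
    import java.util.Date

    object DateFormatSketch {
      // A fresh formatter per call sidesteps SimpleDateFormat's thread-safety hazard.
      def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss")

      def newAppId(submitDate: Date, appNumber: Int): String =
        "app-%s-%04d".format(createDateFormat.format(submitDate), appNumber)

      def main(args: Array[String]) {
        println(newAppId(new Date(), 0))  // e.g. app-20140101120000-0000
      }
    }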
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
index 74a9f8cd824fb..db72d8ae9bdaf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -28,10 +28,6 @@ private[master] object MasterMessages {
case object RevokedLeadership
- // Actor System to LeaderElectionAgent
-
- case object CheckLeader
-
// Actor System to Master
case object CheckForWorkerTimeOut
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala
new file mode 100644
index 0000000000000..4781a80d470e1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import scala.collection.JavaConversions._
+
+import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory}
+import org.apache.curator.retry.ExponentialBackoffRetry
+import org.apache.zookeeper.KeeperException
+
+import org.apache.spark.{Logging, SparkConf}
+
+object SparkCuratorUtil extends Logging {
+
+ val ZK_CONNECTION_TIMEOUT_MILLIS = 15000
+ val ZK_SESSION_TIMEOUT_MILLIS = 60000
+ val RETRY_WAIT_MILLIS = 5000
+ val MAX_RECONNECT_ATTEMPTS = 3
+
+ def newClient(conf: SparkConf): CuratorFramework = {
+ val ZK_URL = conf.get("spark.deploy.zookeeper.url")
+ val zk = CuratorFrameworkFactory.newClient(ZK_URL,
+ ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS,
+ new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS))
+ zk.start()
+ zk
+ }
+
+ def mkdir(zk: CuratorFramework, path: String) {
+ if (zk.checkExists().forPath(path) == null) {
+ try {
+ zk.create().creatingParentsIfNeeded().forPath(path)
+ } catch {
+ case nodeExist: KeeperException.NodeExistsException =>
+ // Do nothing; the node was created concurrently by another client.
+ }
+ }
+ }
+
+ def deleteRecursive(zk: CuratorFramework, path: String) {
+ if (zk.checkExists().forPath(path) != null) {
+ for (child <- zk.getChildren.forPath(path)) {
+ zk.delete().forPath(path + "/" + child)
+ }
+ zk.delete().forPath(path)
+ }
+ }
+}
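
SparkCuratorUtil supports the ZooKeeperLeaderElectionAgent rewrite below, which replaces the hand-rolled SparkZooKeeperSession with Curator's LeaderLatch recipe. A hedged, self-contained sketch of that recipe; the connection string and latch path here are hypothetical (Spark reads them from spark.deploy.zookeeper.url and spark.deploy.zookeeper.dir):

    import org.apache.curator.framework.CuratorFrameworkFactory
    import org.apache.curator.framework.recipes.leader.{LeaderLatch, LeaderLatchListener}
    import org.apache.curator.retry.ExponentialBackoffRetry

    object LeaderLatchSketch {
      def main(args: Array[String]) {
        val zk = CuratorFrameworkFactory.newClient(
          "localhost:2181", 60000, 15000, new ExponentialBackoffRetry(5000, 3))
        zk.start()

        val latch = new LeaderLatch(zk, "/sketch/leader_election")
        latch.addListener(new LeaderLatchListener {
          // Curator invokes these callbacks asynchronously, so leadership may
          // already have been lost by the time isLeader() runs; re-check it.
          override def isLeader() { if (latch.hasLeadership) println("elected leader") }
          override def notLeader() { println("leadership revoked") }
        })
        latch.start()

        Thread.sleep(10000)  // do leader work...
        latch.close()
        zk.close()
      }
    }

The hasLeadership re-check inside isLeader() is the same guard the rewritten agent adds in its own isLeader() callback.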
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
deleted file mode 100644
index 57758055b19c0..0000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.master
-
-import scala.collection.JavaConversions._
-
-import org.apache.zookeeper._
-import org.apache.zookeeper.Watcher.Event.KeeperState
-import org.apache.zookeeper.data.Stat
-
-import org.apache.spark.{Logging, SparkConf}
-
-/**
- * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
- * logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be
- * created. If ZooKeeper remains down after several retries, the given
- * [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be
- * informed via zkDown().
- *
- * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
- * times or a semantic exception is thrown (e.g., "node already exists").
- */
-private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher,
- conf: SparkConf) extends Logging {
- val ZK_URL = conf.get("spark.deploy.zookeeper.url", "")
-
- val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
- val ZK_TIMEOUT_MILLIS = 30000
- val RETRY_WAIT_MILLIS = 5000
- val ZK_CHECK_PERIOD_MILLIS = 10000
- val MAX_RECONNECT_ATTEMPTS = 3
-
- private var zk: ZooKeeper = _
-
- private val watcher = new ZooKeeperWatcher()
- private var reconnectAttempts = 0
- private var closed = false
-
- /** Connect to ZooKeeper to start the session. Must be called before anything else. */
- def connect() {
- connectToZooKeeper()
-
- new Thread() {
- override def run() = sessionMonitorThread()
- }.start()
- }
-
- def sessionMonitorThread(): Unit = {
- while (!closed) {
- Thread.sleep(ZK_CHECK_PERIOD_MILLIS)
- if (zk.getState != ZooKeeper.States.CONNECTED) {
- reconnectAttempts += 1
- val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts
- if (attemptsLeft <= 0) {
- logError("Could not connect to ZooKeeper: system failure")
- zkWatcher.zkDown()
- close()
- } else {
- logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...")
- connectToZooKeeper()
- }
- }
- }
- }
-
- def close() {
- if (!closed && zk != null) { zk.close() }
- closed = true
- }
-
- private def connectToZooKeeper() {
- if (zk != null) zk.close()
- zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher)
- }
-
- /**
- * Attempts to maintain a live ZooKeeper exception despite (very) transient failures.
- * Mainly useful for handling the natural ZooKeeper session expiration.
- */
- private class ZooKeeperWatcher extends Watcher {
- def process(event: WatchedEvent) {
- if (closed) { return }
-
- event.getState match {
- case KeeperState.SyncConnected =>
- reconnectAttempts = 0
- zkWatcher.zkSessionCreated()
- case KeeperState.Expired =>
- connectToZooKeeper()
- case KeeperState.Disconnected =>
- logWarning("ZooKeeper disconnected, will retry...")
- case s => // Do nothing
- }
- }
- }
-
- def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = {
- retry {
- zk.create(path, bytes, ZK_ACL, createMode)
- }
- }
-
- def exists(path: String, watcher: Watcher = null): Stat = {
- retry {
- zk.exists(path, watcher)
- }
- }
-
- def getChildren(path: String, watcher: Watcher = null): List[String] = {
- retry {
- zk.getChildren(path, watcher).toList
- }
- }
-
- def getData(path: String): Array[Byte] = {
- retry {
- zk.getData(path, false, null)
- }
- }
-
- def delete(path: String, version: Int = -1): Unit = {
- retry {
- zk.delete(path, version)
- }
- }
-
- /**
- * Creates the given directory (non-recursively) if it doesn't exist.
- * All znodes are created in PERSISTENT mode with no data.
- */
- def mkdir(path: String) {
- if (exists(path) == null) {
- try {
- create(path, "".getBytes, CreateMode.PERSISTENT)
- } catch {
- case e: Exception =>
- // If the exception caused the directory not to be created, bubble it up,
- // otherwise ignore it.
- if (exists(path) == null) { throw e }
- }
- }
- }
-
- /**
- * Recursively creates all directories up to the given one.
- * All znodes are created in PERSISTENT mode with no data.
- */
- def mkdirRecursive(path: String) {
- var fullDir = ""
- for (dentry <- path.split("/").tail) {
- fullDir += "/" + dentry
- mkdir(fullDir)
- }
- }
-
- /**
- * Retries the given function up to 3 times. The assumption is that failure is transient,
- * UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist),
- * in which case the exception will be thrown without retries.
- *
- * @param fn Block to execute, possibly multiple times.
- */
- def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = {
- try {
- fn
- } catch {
- case e: KeeperException.NoNodeException => throw e
- case e: KeeperException.NodeExistsException => throw e
- case e: Exception if n > 0 =>
- logError("ZooKeeper exception, " + n + " more retries...", e)
- Thread.sleep(RETRY_WAIT_MILLIS)
- retry(fn, n-1)
- }
- }
-}
-
-trait SparkZooKeeperWatcher {
- /**
- * Called whenever a ZK session is created --
- * this will occur when we create our first session as well as each time
- * the session expires or errors out.
- */
- def zkSessionCreated()
-
- /**
- * Called if ZK appears to be completely down (i.e., not just a transient error).
- * We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead.
- */
- def zkDown()
-}
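The deleted `retry` helper above captures the pattern this file was built around: treat most ZooKeeper failures as transient and re-run the by-name block, but let semantic errors (missing or already-existing nodes) propagate immediately. A minimal self-contained sketch of that pattern, with a hypothetical `SemanticException` standing in for `KeeperException.NoNodeException` and friends:

```scala
// Sketch of retry-with-semantic-passthrough; SemanticException is a stand-in.
object RetrySketch {
  class SemanticException(msg: String) extends Exception(msg)

  def retry[T](fn: => T, attempts: Int = 3, waitMillis: Long = 500): T = {
    try {
      fn
    } catch {
      case e: SemanticException => throw e   // caller error: never retry
      case e: Exception if attempts > 0 =>
        println(s"transient failure ($attempts retries left): ${e.getMessage}")
        Thread.sleep(waitMillis)
        retry(fn, attempts - 1, waitMillis)  // recurse with one fewer attempt
    }
  }

  def main(args: Array[String]): Unit = {
    var calls = 0
    val result = retry {
      calls += 1
      if (calls < 3) throw new RuntimeException("flaky") else "ok"
    }
    println(s"result=$result after $calls calls") // result=ok after 3 calls
  }
}
```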
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 47b8f67f8a45b..285f9b014e291 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -18,105 +18,67 @@
package org.apache.spark.deploy.master
import akka.actor.ActorRef
-import org.apache.zookeeper._
-import org.apache.zookeeper.Watcher.Event.EventType
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.master.MasterMessages._
+import org.apache.curator.framework.CuratorFramework
+import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
masterUrl: String, conf: SparkConf)
- extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
+ extends LeaderElectionAgent with LeaderLatchListener with Logging {
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
- private val watcher = new ZooKeeperWatcher()
- private val zk = new SparkZooKeeperSession(this, conf)
+ private var zk: CuratorFramework = _
+ private var leaderLatch: LeaderLatch = _
private var status = LeadershipStatus.NOT_LEADER
- private var myLeaderFile: String = _
- private var leaderUrl: String = _
override def preStart() {
+
logInfo("Starting ZooKeeper LeaderElection agent")
- zk.connect()
- }
+ zk = SparkCuratorUtil.newClient(conf)
+ leaderLatch = new LeaderLatch(zk, WORKING_DIR)
+ leaderLatch.addListener(this)
- override def zkSessionCreated() {
- synchronized {
- zk.mkdirRecursive(WORKING_DIR)
- myLeaderFile =
- zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
- self ! CheckLeader
- }
+ leaderLatch.start()
}
override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
- logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason)
- Thread.sleep(zk.ZK_TIMEOUT_MILLIS)
+ logError("LeaderElectionAgent failed...", reason)
super.preRestart(reason, message)
}
- override def zkDown() {
- logError("ZooKeeper down! LeaderElectionAgent shutting down Master.")
- System.exit(1)
- }
-
override def postStop() {
+ leaderLatch.close()
zk.close()
}
override def receive = {
- case CheckLeader => checkLeader()
+ case _ =>
}
- private class ZooKeeperWatcher extends Watcher {
- def process(event: WatchedEvent) {
- if (event.getType == EventType.NodeDeleted) {
- logInfo("Leader file disappeared, a master is down!")
- self ! CheckLeader
+ override def isLeader() {
+ synchronized {
+ // could have lost leadership by now.
+ if (!leaderLatch.hasLeadership) {
+ return
}
- }
- }
- /** Uses ZK leader election. Navigates several ZK potholes along the way. */
- def checkLeader() {
- val masters = zk.getChildren(WORKING_DIR).toList
- val leader = masters.sorted.head
- val leaderFile = WORKING_DIR + "/" + leader
-
- // Setup a watch for the current leader.
- zk.exists(leaderFile, watcher)
-
- try {
- leaderUrl = new String(zk.getData(leaderFile))
- } catch {
- // A NoNodeException may be thrown if old leader died since the start of this method call.
- // This is fine -- just check again, since we're guaranteed to see the new values.
- case e: KeeperException.NoNodeException =>
- logInfo("Leader disappeared while reading it -- finding next leader")
- checkLeader()
- return
+ logInfo("We have gained leadership")
+ updateLeadershipStatus(true)
}
+ }
- // Synchronization used to ensure no interleaving between the creation of a new session and the
- // checking of a leader, which could cause us to delete our real leader file erroneously.
+ override def notLeader() {
synchronized {
- val isLeader = myLeaderFile == leaderFile
- if (!isLeader && leaderUrl == masterUrl) {
- // We found a different master file pointing to this process.
- // This can happen in the following two cases:
- // (1) The master process was restarted on the same node.
- // (2) The ZK server died between creating the file and returning the name of the file.
- // For this case, we will end up creating a second file, and MUST explicitly delete the
- // first one, since our ZK session is still open.
- // Note that this deletion will cause a NodeDeleted event to be fired so we check again for
- // leader changes.
- assert(leaderFile < myLeaderFile)
- logWarning("Cleaning up old ZK master election file that points to this master.")
- zk.delete(leaderFile)
- } else {
- updateLeadershipStatus(isLeader)
+ // could have gained leadership by now.
+ if (leaderLatch.hasLeadership) {
+ return
}
+
+ logInfo("We have lost leadership")
+ updateLeadershipStatus(false)
}
}
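The rewritten agent drops all of the hand-rolled session and ephemeral-node management in favor of Curator's `LeaderLatch` recipe. A standalone sketch of the same recipe; the connection string and latch path are placeholders, and both callbacks re-check `hasLeadership` inside the handler, just as the agent above does under `synchronized`:

```scala
import org.apache.curator.framework.CuratorFrameworkFactory
import org.apache.curator.framework.recipes.leader.{LeaderLatch, LeaderLatchListener}
import org.apache.curator.retry.ExponentialBackoffRetry

object LeaderLatchSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder connection string; point this at a real ensemble.
    val client = CuratorFrameworkFactory.newClient(
      "localhost:2181", new ExponentialBackoffRetry(1000, 3))
    client.start()

    val latch = new LeaderLatch(client, "/sketch/leader_election")
    latch.addListener(new LeaderLatchListener {
      // Both callbacks can race with reality, so re-check hasLeadership.
      override def isLeader(): Unit =
        if (latch.hasLeadership) println("gained leadership")
      override def notLeader(): Unit =
        if (!latch.hasLeadership) println("lost leadership")
    })
    latch.start()

    Thread.sleep(10000)  // participate in the election for a while
    latch.close()
    client.close()
  }
}
```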
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 48b2fc06a9d70..834dfedee52ce 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,36 +17,29 @@
package org.apache.spark.deploy.master
+import scala.collection.JavaConversions._
+
import akka.serialization.Serialization
-import org.apache.zookeeper._
+import org.apache.curator.framework.CuratorFramework
+import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
extends PersistenceEngine
- with SparkZooKeeperWatcher
with Logging
{
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
+ val zk: CuratorFramework = SparkCuratorUtil.newClient(conf)
- val zk = new SparkZooKeeperSession(this, conf)
-
- zk.connect()
-
- override def zkSessionCreated() {
- zk.mkdirRecursive(WORKING_DIR)
- }
-
- override def zkDown() {
- logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.")
- }
+ SparkCuratorUtil.mkdir(zk, WORKING_DIR)
override def addApplication(app: ApplicationInfo) {
serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
}
override def removeApplication(app: ApplicationInfo) {
- zk.delete(WORKING_DIR + "/app_" + app.id)
+ zk.delete().forPath(WORKING_DIR + "/app_" + app.id)
}
override def addDriver(driver: DriverInfo) {
@@ -54,7 +47,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
}
override def removeDriver(driver: DriverInfo) {
- zk.delete(WORKING_DIR + "/driver_" + driver.id)
+ zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id)
}
override def addWorker(worker: WorkerInfo) {
@@ -62,7 +55,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
}
override def removeWorker(worker: WorkerInfo) {
- zk.delete(WORKING_DIR + "/worker_" + worker.id)
+ zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id)
}
override def close() {
@@ -70,26 +63,34 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
}
override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
+ val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted
val appFiles = sortedFiles.filter(_.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten
val driverFiles = sortedFiles.filter(_.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
+ val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten
val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten
(apps, drivers, workers)
}
private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
- zk.create(path, serialized, CreateMode.PERSISTENT)
+ zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
}
- def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
- val fileData = zk.getData(WORKING_DIR + "/" + filename)
+ def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = {
+ val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
- serializer.fromBinary(fileData).asInstanceOf[T]
+ try {
+ Some(serializer.fromBinary(fileData).asInstanceOf[T])
+ } catch {
+ case e: Exception => {
+ logWarning("Exception while reading persisted file, deleting", e)
+ zk.delete().forPath(WORKING_DIR + "/" + filename)
+ None
+ }
+ }
}
}
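Returning `Option[T]` from `deserializeFromFile` and flattening means one corrupt znode is logged, deleted, and skipped instead of aborting the whole recovery. The same skip-on-failure shape in isolation, with fabricated payloads standing in for znode data:

```scala
// Sketch: decode a list of payloads, dropping the ones that fail to parse.
object SkipCorruptSketch {
  def deserialize(bytes: Array[Byte]): Option[Int] =
    try {
      Some(new String(bytes, "UTF-8").trim.toInt)
    } catch {
      case e: Exception =>
        println(s"skipping corrupt entry: ${e.getMessage}")
        None
    }

  def main(args: Array[String]): Unit = {
    val payloads = Seq("1", "2", "not-a-number", "4").map(_.getBytes("UTF-8"))
    val recovered = payloads.map(deserialize).flatten  // same map-then-flatten as above
    println(recovered)                                 // List(1, 2, 4)
  }
}
```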
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 5cc4adbe448b7..b5cd4d2ea963f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -23,20 +23,21 @@ import scala.concurrent.Await
import scala.xml.Node
import akka.pattern.ask
-import net.liftweb.json.JsonAST.JValue
+import org.json4s.JValue
import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.ExecutorInfo
-import org.apache.spark.ui.UIUtils
+import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
-private[spark] class ApplicationPage(parent: MasterWebUI) {
- val master = parent.masterActorRef
- val timeout = parent.timeout
+private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") {
+
+ private val master = parent.masterActorRef
+ private val timeout = parent.timeout
/** Executor details for a particular application */
- def renderJson(request: HttpServletRequest): JValue = {
+ override def renderJson(request: HttpServletRequest): JValue = {
val appId = request.getParameter("appId")
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
val state = Await.result(stateFuture, timeout)
@@ -82,7 +83,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
deleted file mode 100644
index 01c8f9065e50a..0000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.master.ui
-
-import javax.servlet.http.HttpServletRequest
-
-import scala.concurrent.Await
-import scala.xml.Node
-
-import akka.pattern.ask
-import net.liftweb.json.JsonAST.JValue
-
-import org.apache.spark.deploy.{DeployWebUI, JsonProtocol}
-import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
-import org.apache.spark.ui.UIUtils
-import org.apache.spark.util.Utils
-
-private[spark] class IndexPage(parent: MasterWebUI) {
- val master = parent.masterActorRef
- val timeout = parent.timeout
-
- def renderJson(request: HttpServletRequest): JValue = {
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, timeout)
- JsonProtocol.writeMasterState(state)
- }
-
- /** Index view listing applications and executors */
- def render(request: HttpServletRequest): Seq[Node] = {
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, timeout)
-
- val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory")
- val workers = state.workers.sortBy(_.id)
- val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
-
- val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User",
- "State", "Duration")
- val activeApps = state.activeApps.sortBy(_.startTime).reverse
- val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
- val completedApps = state.completedApps.sortBy(_.endTime).reverse
- val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
-
- val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory",
- "Main Class")
- val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse
- val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers)
- val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse
- val completedDriversTable = UIUtils.listingTable(driverHeaders, driverRow, completedDrivers)
-
- // For now we only show driver information if the user has submitted drivers to the cluster.
- // This is until we integrate the notion of drivers and applications in the UI.
- def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0
-
- val content =
-      <div class="row-fluid">
-        <div class="span12">
-          <ul class="unstyled">
-            <li><strong>URL:</strong> {state.uri}</li>
-            <li><strong>Workers:</strong> {state.workers.size}</li>
-            <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
-              {state.workers.map(_.coresUsed).sum} Used</li>
-            <li><strong>Memory:</strong>
-              {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
-              {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
-          </ul>
-        </div>
-      </div>
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
new file mode 100644
index 0000000000000..7ca3b08a28728
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.concurrent.Await
+import scala.xml.Node
+
+import akka.pattern.ask
+import org.json4s.JValue
+
+import org.apache.spark.deploy.JsonProtocol
+import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.util.Utils
+
+private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
+ private val master = parent.masterActorRef
+ private val timeout = parent.timeout
+
+ override def renderJson(request: HttpServletRequest): JValue = {
+ val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
+ val state = Await.result(stateFuture, timeout)
+ JsonProtocol.writeMasterState(state)
+ }
+
+ /** Index view listing applications and executors */
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
+ val state = Await.result(stateFuture, timeout)
+
+ val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory")
+ val workers = state.workers.sortBy(_.id)
+ val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
+
+ val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User",
+ "State", "Duration")
+ val activeApps = state.activeApps.sortBy(_.startTime).reverse
+ val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
+ val completedApps = state.completedApps.sortBy(_.endTime).reverse
+ val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
+
+ val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory",
+ "Main Class")
+ val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse
+ val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers)
+ val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse
+ val completedDriversTable = UIUtils.listingTable(driverHeaders, driverRow, completedDrivers)
+
+ // For now we only show driver information if the user has submitted drivers to the cluster.
+ // This is until we integrate the notion of drivers and applications in the UI.
+ def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0
+
+ val content =
+      <div class="row-fluid">
+        <div class="span12">
+          <ul class="unstyled">
+            <li><strong>URL:</strong> {state.uri}</li>
+            <li><strong>Workers:</strong> {state.workers.size}</li>
+            <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
+              {state.workers.map(_.coresUsed).sum} Used</li>
+            <li><strong>Memory:</strong>
+              {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
+              {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
+          </ul>
+        </div>
+      </div>
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
new file mode 100644
index 0000000000000..8381f59672ea3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker.ui
+
+import java.io.File
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.util.Utils
+
+private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") {
+ private val worker = parent.worker
+ private val workDir = parent.workDir
+
+ def renderLog(request: HttpServletRequest): String = {
+ val defaultBytes = 100 * 1024
+
+ val appId = Option(request.getParameter("appId"))
+ val executorId = Option(request.getParameter("executorId"))
+ val driverId = Option(request.getParameter("driverId"))
+ val logType = request.getParameter("logType")
+ val offset = Option(request.getParameter("offset")).map(_.toLong)
+ val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+
+ val path = (appId, executorId, driverId) match {
+ case (Some(a), Some(e), None) =>
+ s"${workDir.getPath}/$appId/$executorId/$logType"
+ case (None, None, Some(d)) =>
+ s"${workDir.getPath}/$driverId/$logType"
+ case _ =>
+ throw new Exception("Request must specify either application or driver identifiers")
+ }
+
+ val (startByte, endByte) = getByteRange(path, offset, byteLength)
+ val file = new File(path)
+ val logLength = file.length
+
+ val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n"
+ pre + Utils.offsetBytes(path, startByte, endByte)
+ }
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val defaultBytes = 100 * 1024
+ val appId = Option(request.getParameter("appId"))
+ val executorId = Option(request.getParameter("executorId"))
+ val driverId = Option(request.getParameter("driverId"))
+ val logType = request.getParameter("logType")
+ val offset = Option(request.getParameter("offset")).map(_.toLong)
+ val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+
+ val (path, params) = (appId, executorId, driverId) match {
+ case (Some(a), Some(e), None) =>
+ (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e")
+ case (None, None, Some(d)) =>
+ (s"${workDir.getPath}/$d/$logType", s"driverId=$d")
+ case _ =>
+ throw new Exception("Request must specify either application or driver identifiers")
+ }
+
+ val (startByte, endByte) = getByteRange(path, offset, byteLength)
+ val file = new File(path)
+ val logLength = file.length
+    val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node>
+    val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
+    val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
+
+ val backButton =
+ if (startByte > 0) {
+
+
+
+ } else {
+
+ }
+
+ val nextButton =
+ if (endByte < logLength) {
+
+
+
+ } else {
+
+ }
+
+ val content =
+
+
+ {linkToMaster}
+
+
{backButton}
+
{range}
+
{nextButton}
+
+
+
+
{logText}
+
+
+
+ UIUtils.basicSparkPage(content, logType + " log page for " + appId)
+ }
+
+ /** Determine the byte range for a log or log page. */
+ private def getByteRange(path: String, offset: Option[Long], byteLength: Int): (Long, Long) = {
+ val defaultBytes = 100 * 1024
+ val maxBytes = 1024 * 1024
+ val file = new File(path)
+ val logLength = file.length()
+ val getOffset = offset.getOrElse(logLength - defaultBytes)
+ val startByte =
+ if (getOffset < 0) {
+ 0L
+ } else if (getOffset > logLength) {
+ logLength
+ } else {
+ getOffset
+ }
+ val logPageLength = math.min(byteLength, maxBytes)
+ val endByte = math.min(startByte + logPageLength, logLength)
+ (startByte, endByte)
+ }
+}
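`getByteRange` clamps the requested offset into `[0, logLength]` and caps the page at one megabyte before computing the end byte. The arithmetic on its own, detached from the file system, with a few worked cases (names mirror the method above):

```scala
// Sketch of the clamping arithmetic in getByteRange.
object ByteRangeSketch {
  val defaultBytes = 100 * 1024
  val maxBytes = 1024 * 1024

  def byteRange(logLength: Long, offset: Option[Long], byteLength: Int): (Long, Long) = {
    val requested = offset.getOrElse(logLength - defaultBytes) // default: tail of the log
    val start = math.min(math.max(requested, 0L), logLength)   // clamp into [0, logLength]
    val end = math.min(start + math.min(byteLength, maxBytes), logLength)
    (start, end)
  }

  def main(args: Array[String]): Unit = {
    println(byteRange(50 * 1024, None, defaultBytes))              // (0,51200): short log, read it all
    println(byteRange(10 * 1024 * 1024, Some(-5L), defaultBytes))  // (0,102400): negative offset clamped
    println(byteRange(10 * 1024 * 1024, Some(20L * 1024 * 1024), defaultBytes)) // (10485760,10485760)
  }
}
```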
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
new file mode 100644
index 0000000000000..d4513118ced05
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker.ui
+
+import scala.concurrent.Await
+import scala.xml.Node
+
+import akka.pattern.ask
+import javax.servlet.http.HttpServletRequest
+import org.json4s.JValue
+
+import org.apache.spark.deploy.JsonProtocol
+import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse}
+import org.apache.spark.deploy.master.DriverState
+import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.util.Utils
+
+private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
+ val workerActor = parent.worker.self
+ val worker = parent.worker
+ val timeout = parent.timeout
+
+ override def renderJson(request: HttpServletRequest): JValue = {
+ val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
+ val workerState = Await.result(stateFuture, timeout)
+ JsonProtocol.writeWorkerState(workerState)
+ }
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
+ val workerState = Await.result(stateFuture, timeout)
+
+ val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs")
+ val runningExecutorTable =
+ UIUtils.listingTable(executorHeaders, executorRow, workerState.executors)
+ val finishedExecutorTable =
+ UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors)
+
+ val driverHeaders = Seq("DriverID", "Main Class", "State", "Cores", "Memory", "Logs", "Notes")
+ val runningDrivers = workerState.drivers.sortBy(_.driverId).reverse
+ val runningDriverTable = UIUtils.listingTable(driverHeaders, driverRow, runningDrivers)
+ val finishedDrivers = workerState.finishedDrivers.sortBy(_.driverId).reverse
+ def finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers)
+
+ // For now we only show driver information if the user has submitted drivers to the cluster.
+ // This is until we integrate the notion of drivers and applications in the UI.
+ def hasDrivers = runningDrivers.length > 0 || finishedDrivers.length > 0
+
+ val content =
+
-
- val range = Bytes {startByte.toString} - {endByte.toString} of {logLength}
-
- val backButton =
- if (startByte > 0) {
-
-
-
- }
- else {
-
- }
-
- val nextButton =
- if (endByte < logLength) {
-
-
-
- }
- else {
-
- }
-
- val content =
-
-
- {linkToMaster}
-
-
{backButton}
-
{range}
-
{nextButton}
-
-
-
-
{logText}
-
-
-
- UIUtils.basicSparkPage(content, logType + " log page for " + appId)
- }
-
- /** Determine the byte range for a log or log page. */
- def getByteRange(path: String, offset: Option[Long], byteLength: Int)
- : (Long, Long) = {
- val defaultBytes = 100 * 1024
- val maxBytes = 1024 * 1024
-
- val file = new File(path)
- val logLength = file.length()
- val getOffset = offset.getOrElse(logLength-defaultBytes)
-
- val startByte =
- if (getOffset < 0) 0L
- else if (getOffset > logLength) logLength
- else getOffset
-
- val logPageLength = math.min(byteLength, maxBytes)
-
- val endByte = math.min(startByte + logPageLength, logLength)
-
- (startByte, endByte)
- }
+ val timeout = AkkaUtils.askTimeout(worker.conf)
- def stop() {
- server.foreach(_.stop())
+ initialize()
+
+ /** Initialize all components of the server. */
+ def initialize() {
+ val logPage = new LogPage(this)
+ attachPage(logPage)
+ attachPage(new WorkerPage(this))
+ attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static"))
+ attachHandler(createServletHandler("/log",
+ (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr))
+ worker.metricsSystem.getServletHandlers.foreach(attachHandler)
}
}
private[spark] object WorkerWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
- val DEFAULT_PORT="8081"
+ val DEFAULT_PORT = 8081
+ val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR
+
+ def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = {
+ requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT))
+ }
}
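`getUIPort` resolves the port in three steps: an explicitly requested port wins, then the `worker.ui.port` configuration key, then the default. The same fallback chain sketched with a plain `Map` standing in for `SparkConf`:

```scala
// Sketch of the port-resolution chain in getUIPort.
object PortSketch {
  def resolvePort(requested: Option[Int], configured: Map[String, String]): Int =
    requested.getOrElse(configured.get("worker.ui.port").map(_.toInt).getOrElse(8081))

  def main(args: Array[String]): Unit = {
    println(resolvePort(Some(9090), Map("worker.ui.port" -> "8082"))) // 9090: explicit request wins
    println(resolvePort(None, Map("worker.ui.port" -> "8082")))       // 8082: config key
    println(resolvePort(None, Map.empty))                             // 8081: default
  }
}
```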
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 0aae569b17272..9ac7365f47f9f 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import akka.actor._
import akka.remote._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -53,7 +53,8 @@ private[spark] class CoarseGrainedExecutorBackend(
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
// Make this host instead of hostPort ?
- executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties,
+ false)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -68,12 +69,12 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}
- case KillTask(taskId, _) =>
+ case KillTask(taskId, _, interruptThread) =>
if (executor == null) {
logError("Received KillTask command but executor was null")
System.exit(1)
} else {
- executor.killTask(taskId)
+ executor.killTask(taskId, interruptThread)
}
case x: DisassociatedEvent =>
@@ -97,14 +98,16 @@ private[spark] object CoarseGrainedExecutorBackend {
// Debug code
Utils.checkHost(hostname)
+ val conf = new SparkConf
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
- indestructible = true, conf = new SparkConf)
+ indestructible = true, conf = conf, new SecurityManager(conf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
- Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
+ Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId,
+ sparkHostPort, cores),
name = "Executor")
workerUrl.foreach{ url =>
actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 989d666f15600..272bcda5f8f2f 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -64,16 +64,12 @@ private[spark] class Executor(
// to what Yarn on this system said was available. This will be used later when SparkEnv
// created.
if (java.lang.Boolean.valueOf(
- System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))))
- {
+ System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))) {
conf.set("spark.local.dir", getYarnLocalDirs())
+ } else if (sys.env.contains("SPARK_LOCAL_DIRS")) {
+ conf.set("spark.local.dir", sys.env("SPARK_LOCAL_DIRS"))
}
- // Create our ClassLoader and set it on this thread
- private val urlClassLoader = createClassLoader()
- private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
- Thread.currentThread.setContextClassLoader(replClassLoader)
-
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
@@ -117,11 +113,14 @@ private[spark] class Executor(
}
}
+ // Create our ClassLoader
+  // do this after SparkEnv creation so we can access the SecurityManager
+ private val urlClassLoader = createClassLoader()
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
- private val akkaFrameSize = {
- env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
- }
+ private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -137,10 +136,10 @@ private[spark] class Executor(
threadPool.execute(tr)
}
- def killTask(taskId: Long) {
+ def killTask(taskId: Long, interruptThread: Boolean) {
val tr = runningTasks.get(taskId)
if (tr != null) {
- tr.kill()
+ tr.kill(interruptThread)
}
}
@@ -162,16 +161,14 @@ private[spark] class Executor(
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
- object TaskKilledException extends Exception
-
@volatile private var killed = false
@volatile private var task: Task[Any] = _
- def kill() {
+ def kill(interruptThread: Boolean) {
logInfo("Executor is trying to kill task " + taskId)
killed = true
if (task != null) {
- task.kill()
+ task.kill(interruptThread)
}
}
@@ -201,7 +198,7 @@ private[spark] class Executor(
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
- throw TaskKilledException
+ throw new TaskKilledException
}
attemptedTask = Some(task)
@@ -215,7 +212,7 @@ private[spark] class Executor(
// If the task has been killed, let's fail it.
if (task.killed) {
- throw TaskKilledException
+ throw new TaskKilledException
}
val resultSer = SparkEnv.get.serializer.newInstance()
@@ -225,10 +222,10 @@ private[spark] class Executor(
for (m <- task.metrics) {
m.hostname = Utils.localHostName()
- m.executorDeserializeTime = (taskStart - startTime).toInt
- m.executorRunTime = (taskFinish - taskStart).toInt
+ m.executorDeserializeTime = taskStart - startTime
+ m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
- m.resultSerializationTime = (afterSerialization - beforeSerialization).toInt
+ m.resultSerializationTime = afterSerialization - beforeSerialization
}
val accumUpdates = Accumulators.values
@@ -258,13 +255,13 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
- case TaskKilledException => {
+ case _: TaskKilledException | _: InterruptedException if task.killed => {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
case t: Throwable => {
- val serviceTime = (System.currentTimeMillis() - taskStart).toInt
+ val serviceTime = System.currentTimeMillis() - taskStart
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.executorRunTime = serviceTime
@@ -277,7 +274,6 @@ private[spark] class Executor(
// have left some weird state around depending on when the exception was thrown, but on
// the other hand, maybe we could detect that when future tasks fail and exit then.
logError("Exception in task ID " + taskId, t)
- //System.exit(1)
}
} finally {
// TODO: Unregister shuffle memory only for ResultTask
@@ -294,15 +290,19 @@ private[spark] class Executor(
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
- private def createClassLoader(): ExecutorURLClassLoader = {
- val loader = this.getClass.getClassLoader
+ private def createClassLoader(): MutableURLClassLoader = {
+ val currentLoader = Utils.getContextOrSparkClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
- new ExecutorURLClassLoader(urls, loader)
+ val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
+ userClassPathFirst match {
+ case true => new ChildExecutorURLClassLoader(urls, currentLoader)
+ case false => new ExecutorURLClassLoader(urls, currentLoader)
+ }
}
/**
@@ -313,11 +313,14 @@ private[spark] class Executor(
val classUri = conf.get("spark.repl.class.uri", null)
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
+ val userClassPathFirst: java.lang.Boolean =
+ conf.getBoolean("spark.files.userClassPathFirst", false)
try {
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
- val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- constructor.newInstance(classUri, parent)
+ val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader],
+ classOf[Boolean])
+ constructor.newInstance(classUri, parent, userClassPathFirst)
} catch {
case _: ClassNotFoundException =>
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@@ -338,12 +341,12 @@ private[spark] class Executor(
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
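The dependency loop above fetches a file or jar only when its advertised timestamp is newer than the cached one. The same guard in isolation:

```scala
// Sketch of the timestamp-gated refresh: fetch only entries whose advertised
// timestamp is newer than what we already have.
import scala.collection.mutable

object RefreshSketch {
  val current = mutable.Map[String, Long]()

  def refresh(advertised: Map[String, Long]): Seq[String] = {
    for ((name, ts) <- advertised.toSeq if current.getOrElse(name, -1L) < ts) yield {
      current(name) = ts // pretend we fetched the file here
      name
    }
  }

  def main(args: Array[String]): Unit = {
    println(refresh(Map("a.jar" -> 1L, "b.jar" -> 1L))) // List(a.jar, b.jar)
    println(refresh(Map("a.jar" -> 1L, "b.jar" -> 2L))) // List(b.jar): only the newer one
  }
}
```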
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
index 210f3dbeebaca..38be2c58b333f 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -34,13 +34,19 @@ object ExecutorExitCode {
logging the exception. */
val UNCAUGHT_EXCEPTION_TWICE = 51
- /** The default uncaught exception handler was reached, and the uncaught exception was an
+ /** The default uncaught exception handler was reached, and the uncaught exception was an
OutOfMemoryError. */
val OOM = 52
/** DiskStore failed to create a local temporary directory after many attempts. */
val DISK_STORE_FAILED_TO_CREATE_DIR = 53
+ /** TachyonStore failed to initialize after many attempts. */
+ val TACHYON_STORE_FAILED_TO_INITIALIZE = 54
+
+ /** TachyonStore failed to create a local temporary directory after many attempts. */
+ val TACHYON_STORE_FAILED_TO_CREATE_DIR = 55
+
def explainExitCode(exitCode: Int): String = {
exitCode match {
case UNCAUGHT_EXCEPTION => "Uncaught exception"
@@ -48,7 +54,10 @@ object ExecutorExitCode {
case OOM => "OutOfMemoryError"
case DISK_STORE_FAILED_TO_CREATE_DIR =>
"Failed to create local directory (bad spark.local.dir?)"
- case _ =>
+ case TACHYON_STORE_FAILED_TO_INITIALIZE => "TachyonStore failed to initialize."
+ case TACHYON_STORE_FAILED_TO_CREATE_DIR =>
+ "TachyonStore failed to create a local temporary directory."
+ case _ =>
"Unknown executor exit code (" + exitCode + ")" + (
if (exitCode > 128) {
" (died from signal " + (exitCode - 128) + "?)"
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index 127f5e90f3e1a..0ed52cfe9df61 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.FileSystem
import org.apache.spark.metrics.source.Source
-class ExecutorSource(val executor: Executor, executorId: String) extends Source {
+private[spark] class ExecutorSource(val executor: Executor, executorId: String) extends Source {
private def fileStats(scheme: String) : Option[FileSystem.Statistics] =
FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
index f9bfe8ed2f5ba..218ed7b5d2d39 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
@@ -19,13 +19,56 @@ package org.apache.spark.executor
import java.net.{URLClassLoader, URL}
+import org.apache.spark.util.ParentClassLoader
+
/**
* The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
+ * We also make changes so that user classes can take precedence over the default classes.
*/
+
+private[spark] trait MutableURLClassLoader extends ClassLoader {
+ def addURL(url: URL)
+ def getURLs: Array[URL]
+}
+
+private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
+ extends MutableURLClassLoader {
+
+ private object userClassLoader extends URLClassLoader(urls, null){
+ override def addURL(url: URL) {
+ super.addURL(url)
+ }
+ override def findClass(name: String): Class[_] = {
+ super.findClass(name)
+ }
+ }
+
+ private val parentClassLoader = new ParentClassLoader(parent)
+
+ override def findClass(name: String): Class[_] = {
+ try {
+ userClassLoader.findClass(name)
+ } catch {
+ case e: ClassNotFoundException => {
+ parentClassLoader.loadClass(name)
+ }
+ }
+ }
+
+ def addURL(url: URL) {
+ userClassLoader.addURL(url)
+ }
+
+ def getURLs() = {
+ userClassLoader.getURLs()
+ }
+}
+
private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
- extends URLClassLoader(urls, parent) {
+ extends URLClassLoader(urls, parent) with MutableURLClassLoader {
override def addURL(url: URL) {
super.addURL(url)
}
}
+
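`ChildExecutorURLClassLoader` reverses the usual parent-first delegation: the user jars are tried first, and the parent is consulted only on a `ClassNotFoundException`. A standalone variant of the same delegation order (it overrides `loadClass` rather than `findClass`, and the jar path is a placeholder):

```scala
import java.net.{URL, URLClassLoader}

// Sketch: child-first delegation. The child URLClassLoader has a null parent,
// so it only sees the user jars (plus the bootstrap classes); anything it
// cannot find falls through to the real parent.
class ChildFirstLoader(urls: Array[URL], parent: ClassLoader)
  extends ClassLoader(parent) {

  private val child = new URLClassLoader(urls, null)

  override def loadClass(name: String, resolve: Boolean): Class[_] = {
    val c =
      try child.loadClass(name) // user classes win
      catch { case _: ClassNotFoundException => super.loadClass(name, false) }
    if (resolve) resolveClass(c)
    c
  }
}

object LoaderSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder jar path; any jar that shadows a parent class will do.
    val loader = new ChildFirstLoader(
      Array(new URL("file:/tmp/user.jar")), getClass.getClassLoader)
    // Classes absent from the jar still resolve normally:
    println(loader.loadClass("java.lang.String"))
  }
}
```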
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index 6fc702fdb1512..64e24506e8038 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -76,7 +76,8 @@ private[spark] class MesosExecutorBackend
if (executor == null) {
logError("Received KillTask but executor was null")
} else {
- executor.killTask(t.getValue.toLong)
+ // TODO: Determine the 'interruptOnCancel' property set for the given job.
+ executor.killTask(t.getValue.toLong, interruptThread = false)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 455339943f42d..350fd74173f65 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,21 +17,29 @@
package org.apache.spark.executor
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.storage.{BlockId, BlockStatus}
+
+/**
+ * :: DeveloperApi ::
+ * Metrics tracked during the execution of a task.
+ */
+@DeveloperApi
class TaskMetrics extends Serializable {
/**
- * Host's name the task runs on
+ * Name of the host the task runs on
*/
var hostname: String = _
/**
* Time taken on the executor to deserialize this task
*/
- var executorDeserializeTime: Int = _
+ var executorDeserializeTime: Long = _
/**
* Time the executor spends actually running the task (including fetching shuffle data)
*/
- var executorRunTime: Int = _
+ var executorRunTime: Long = _
/**
* The number of bytes this task transmitted back to the driver as the TaskResult
@@ -68,13 +76,23 @@ class TaskMetrics extends Serializable {
* here
*/
var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+
+ /**
+ * Storage statuses of any blocks that have been updated as a result of this task.
+ */
+ var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
}
-object TaskMetrics {
- private[spark] def empty(): TaskMetrics = new TaskMetrics
+private[spark] object TaskMetrics {
+ def empty: TaskMetrics = new TaskMetrics
}
+/**
+ * :: DeveloperApi ::
+ * Metrics pertaining to shuffle data read in a given task.
+ */
+@DeveloperApi
class ShuffleReadMetrics extends Serializable {
/**
* Absolute time when this task finished reading shuffle data
@@ -103,19 +121,17 @@ class ShuffleReadMetrics extends Serializable {
*/
var fetchWaitTime: Long = _
- /**
- * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
- * input blocks. Since block fetches are both pipelined and parallelized, this can
- * exceed fetchWaitTime and executorRunTime.
- */
- var remoteFetchTime: Long = _
-
/**
* Total number of remote bytes read from the shuffle by this task
*/
var remoteBytesRead: Long = _
}
+/**
+ * :: DeveloperApi ::
+ * Metrics pertaining to shuffle data written in a given task.
+ */
+@DeveloperApi
class ShuffleWriteMetrics extends Serializable {
/**
* Number of bytes written for the shuffle by this task
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
new file mode 100644
index 0000000000000..4cb450577796a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
+import org.apache.hadoop.mapreduce.RecordReader
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader
+import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+
+/**
+ * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for
+ * reading whole text files. Each file is read as a key-value pair, where the key is the file path
+ * and the value is the entire content of the file.
+ */
+
+private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] {
+ override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+
+ override def createRecordReader(
+ split: InputSplit,
+ context: TaskAttemptContext): RecordReader[String, String] = {
+
+ new CombineFileRecordReader[String, String](
+ split.asInstanceOf[CombineFileSplit],
+ context,
+ classOf[WholeTextFileRecordReader])
+ }
+
+ /**
+ * Allow minPartitions to be set by the end user, to keep compatibility with the old Hadoop API.
+ */
+ def setMaxSplitSize(context: JobContext, minPartitions: Int) {
+ val files = listStatus(context)
+ val totalLen = files.map { file =>
+ if (file.isDir) 0L else file.getLen
+ }.sum
+ val maxSplitSize = Math.ceil(totalLen * 1.0 /
+ (if (minPartitions == 0) 1 else minPartitions)).toLong
+ super.setMaxSplitSize(maxSplitSize)
+ }
+}
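`setMaxSplitSize` sizes combined splits so that roughly `minPartitions` splits come out: the cap is the total input length divided by the requested partition count, rounded up. The arithmetic with a worked case:

```scala
// Sketch of the split-size arithmetic in setMaxSplitSize.
object SplitSizeSketch {
  def maxSplitSize(totalLen: Long, minPartitions: Int): Long =
    Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong

  def main(args: Array[String]): Unit = {
    // Ten 1 MB files, four requested partitions: cap each split at 2.5 MB,
    // so CombineFileInputFormat packs the files into roughly four splits.
    println(maxSplitSize(10L * 1024 * 1024, 4)) // 2621440
    println(maxSplitSize(10L * 1024 * 1024, 0)) // 10485760 (0 treated as 1)
  }
}
```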
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
new file mode 100644
index 0000000000000..c3dabd2e79995
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import com.google.common.io.{ByteStreams, Closeables}
+
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+import org.apache.hadoop.mapreduce.RecordReader
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+
+/**
+ * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
+ * as a key-value pair, where the key is the file path and the value is the entire content of
+ * the file.
+ */
+private[spark] class WholeTextFileRecordReader(
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends RecordReader[String, String] {
+
+ private val path = split.getPath(index)
+ private val fs = path.getFileSystem(context.getConfiguration)
+
+  // True means the current file has been processed and should be skipped.
+ private var processed = false
+
+ private val key = path.toString
+ private var value: String = null
+
+ override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+
+ override def close() = {}
+
+ override def getProgress = if (processed) 1.0f else 0.0f
+
+ override def getCurrentKey = key
+
+ override def getCurrentValue = value
+
+ override def nextKeyValue = {
+ if (!processed) {
+ val fileIn = fs.open(path)
+ val innerBuffer = ByteStreams.toByteArray(fileIn)
+
+ value = new Text(innerBuffer).toString
+ Closeables.close(fileIn, false)
+
+ processed = true
+ true
+ } else {
+ false
+ }
+ }
+}
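Together, the format and the reader turn every file into a single (path, contents) record that is never split. In Spark this pair backs `SparkContext.wholeTextFiles`; a usage sketch, assuming a local directory of small files (the directory is a placeholder):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hedged usage sketch: read every file under a directory as one (path, contents)
// record via the whole-text-file input format.
object WholeTextFilesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("whole-text"))
    // minPartitions feeds setMaxSplitSize above: it caps split sizes so that
    // roughly that many combined splits are produced.
    val files = sc.wholeTextFiles("/tmp/small-files", minPartitions = 2)
    files.collect().foreach { case (path, contents) =>
      println(s"$path: ${contents.length} chars")
    }
    sc.stop()
  }
}
```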
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 848b5c439bb5b..e1a5ee316bb69 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -23,11 +23,18 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream}
import org.apache.spark.SparkConf
+import org.apache.spark.annotation.DeveloperApi
/**
+ * :: DeveloperApi ::
* CompressionCodec allows the customization of choosing different compression implementations
* to be used in block storage.
+ *
+ * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark.
+ * This is intended for use as an internal compression utility within a single
+ * Spark application.
*/
+@DeveloperApi
trait CompressionCodec {
def compressedOutputStream(s: OutputStream): OutputStream
@@ -38,8 +45,7 @@ trait CompressionCodec {
private[spark] object CompressionCodec {
def createCodec(conf: SparkConf): CompressionCodec = {
- createCodec(conf, conf.get(
- "spark.io.compression.codec", classOf[LZFCompressionCodec].getName))
+ createCodec(conf, conf.get("spark.io.compression.codec", DEFAULT_COMPRESSION_CODEC))
}
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
@@ -47,12 +53,20 @@ private[spark] object CompressionCodec {
.getConstructor(classOf[SparkConf])
ctor.newInstance(conf).asInstanceOf[CompressionCodec]
}
+
+ val DEFAULT_COMPRESSION_CODEC = classOf[LZFCompressionCodec].getName
}
/**
+ * :: DeveloperApi ::
* LZF implementation of [[org.apache.spark.io.CompressionCodec]].
+ *
+ * Note: The wire protocol for this codec is not guaranteed to be compatible across versions
+ * of Spark. This is intended for use as an internal compression utility within a single Spark
+ * application.
*/
+@DeveloperApi
class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
override def compressedOutputStream(s: OutputStream): OutputStream = {
@@ -64,9 +78,15 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
/**
+ * :: DeveloperApi ::
* Snappy implementation of [[org.apache.spark.io.CompressionCodec]].
* Block size can be configured by spark.io.compression.snappy.block.size.
+ *
+ * Note: The wire protocol for this codec is not guaranteed to be compatible across versions
+ * of Spark. This is intended for use as an internal compression utility within a single Spark
+ * application.
*/
+@DeveloperApi
class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
override def compressedOutputStream(s: OutputStream): OutputStream = {
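The codecs above are thin wrappers that turn raw streams into compressed ones. A sketch that round-trips a payload through the LZF codec, constructed directly since `createCodec` is `private[spark]`; it assumes spark-core (and its compress-lzf dependency) on the classpath:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.io.LZFCompressionCodec

// Sketch: round-trip a payload through one of the codecs defined above.
object CodecSketch {
  def main(args: Array[String]): Unit = {
    val codec = new LZFCompressionCodec(new SparkConf())

    val payload = ("spark " * 1000).getBytes("UTF-8")
    val buf = new ByteArrayOutputStream()
    val out = codec.compressedOutputStream(buf)
    out.write(payload)
    out.close()
    println(s"${payload.length} bytes -> ${buf.size} compressed")

    val in = codec.compressedInputStream(new ByteArrayInputStream(buf.toByteArray))
    val restored = Stream.continually(in.read()).takeWhile(_ != -1).map(_.toByte).toArray
    assert(restored.sameElements(payload)) // lossless round-trip
  }
}
```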
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
index 6883a54494598..1b7a5d1f1980a 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.matching.Regex
import org.apache.spark.Logging
+import org.apache.spark.util.Utils
private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
@@ -42,7 +43,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
}
def initialize() {
- //Add default properties in case there's no properties file
+ // Add default properties in case there's no properties file
setDefaultProperties(properties)
// If spark.metrics.conf is not set, try to get file in class path
@@ -50,7 +51,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
try {
is = configFile match {
case Some(f) => new FileInputStream(f)
- case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF)
+ case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF)
}
if (is != null) {
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 966c092124266..651511da1b7fe 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
import org.apache.spark.metrics.source.Source
@@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source
* [options] is the specific property of this source or sink.
*/
private[spark] class MetricsSystem private (val instance: String,
- conf: SparkConf) extends Logging {
+ conf: SparkConf, securityMgr: SecurityManager) extends Logging {
val confFile = conf.get("spark.metrics.conf", null)
val metricsConfig = new MetricsConfig(Option(confFile))
@@ -129,17 +129,19 @@ private[spark] class MetricsSystem private (val instance: String,
sinkConfigs.foreach { kv =>
val classPath = kv._2.getProperty("class")
- try {
- val sink = Class.forName(classPath)
- .getConstructor(classOf[Properties], classOf[MetricRegistry])
- .newInstance(kv._2, registry)
- if (kv._1 == "servlet") {
- metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
- } else {
- sinks += sink.asInstanceOf[Sink]
+ if (null != classPath) {
+ try {
+ val sink = Class.forName(classPath)
+ .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
+ if (kv._1 == "servlet") {
+ metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
+ } else {
+ sinks += sink.asInstanceOf[Sink]
+ }
+ } catch {
+ case e: Exception => logError("Sink class " + classPath + " cannot be instantiated", e)
}
- } catch {
- case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e)
}
}
}
@@ -160,6 +162,7 @@ private[spark] object MetricsSystem {
}
}
- def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem =
- new MetricsSystem(instance, conf)
+ def createMetricsSystem(instance: String, conf: SparkConf,
+ securityMgr: SecurityManager): MetricsSystem =
+ new MetricsSystem(instance, conf, securityMgr)
}
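A sketch of the updated creation contract; the instance name is illustrative, and the SecurityManager(conf) constructor is the one used elsewhere in this patch:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.metrics.MetricsSystem

// Sketch (ignoring private[spark] visibility): every creator of a
// MetricsSystem now supplies a SecurityManager, which servlet-based
// sinks use to authenticate requests to their handlers.
val conf = new SparkConf()
val securityMgr = new SecurityManager(conf)
val metricsSystem = MetricsSystem.createMetricsSystem("driver", conf, securityMgr)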
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
index 98fa1dbd7c6ab..05852f1f98993 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{ConsoleReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+private[spark] class ConsoleSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CONSOLE_DEFAULT_PERIOD = 10
val CONSOLE_DEFAULT_UNIT = "SECONDS"
@@ -36,7 +38,7 @@ class ConsoleSink(val property: Properties, val registry: MetricRegistry) extend
case None => CONSOLE_DEFAULT_PERIOD
}
- val pollUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match {
+ val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match {
case Some(s) => TimeUnit.valueOf(s.toUpperCase())
case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
index 40f64768e6885..542dce65366b2 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{CsvReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+private[spark] class CsvSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CSV_KEY_PERIOD = "period"
val CSV_KEY_UNIT = "unit"
val CSV_KEY_DIR = "directory"
@@ -39,11 +41,11 @@ class CsvSink(val property: Properties, val registry: MetricRegistry) extends Si
case None => CSV_DEFAULT_PERIOD
}
- val pollUnit = Option(property.getProperty(CSV_KEY_UNIT)) match {
+ val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) match {
case Some(s) => TimeUnit.valueOf(s.toUpperCase())
case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT)
}
-
+
MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
val pollDir = Option(property.getProperty(CSV_KEY_DIR)) match {
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index e09be001421fc..aeb4ad44a0647 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+private[spark] class GraphiteSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GRAPHITE_DEFAULT_PERIOD = 10
val GRAPHITE_DEFAULT_UNIT = "SECONDS"
val GRAPHITE_DEFAULT_PREFIX = ""
@@ -37,7 +39,7 @@ class GraphiteSink(val property: Properties, val registry: MetricRegistry) exten
val GRAPHITE_KEY_UNIT = "unit"
val GRAPHITE_KEY_PREFIX = "prefix"
- def propertyToOption(prop: String) = Option(property.getProperty(prop))
+ def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop))
if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) {
throw new Exception("Graphite sink requires 'host' property.")
@@ -55,7 +57,7 @@ class GraphiteSink(val property: Properties, val registry: MetricRegistry) exten
case None => GRAPHITE_DEFAULT_PERIOD
}
- val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
+ val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
case Some(s) => TimeUnit.valueOf(s.toUpperCase())
case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
index b5cf210af2119..ed27234b4e760 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import com.codahale.metrics.{JmxReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
+
+private[spark] class JmxSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
-class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink {
val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
override def start() {
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index 3cdfe26d40f66..571539ba5e467 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import java.util.concurrent.TimeUnit
+
import javax.servlet.http.HttpServletRequest
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.json.MetricsModule
import com.fasterxml.jackson.databind.ObjectMapper
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
-import org.apache.spark.ui.JettyUtils
+import org.apache.spark.SecurityManager
+import org.apache.spark.ui.JettyUtils._
-class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink {
+private[spark] class MetricsServlet(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val SERVLET_KEY_PATH = "path"
val SERVLET_KEY_SAMPLE = "sample"
@@ -42,8 +45,9 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext
val mapper = new ObjectMapper().registerModule(
new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample))
- def getHandlers = Array[(String, Handler)](
- (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json"))
+ def getHandlers = Array[ServletContextHandler](
+ createServletHandler(servletPath,
+ new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr)
)
def getMetricsSnapshot(request: HttpServletRequest): String = {
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala
index 3a739aa563eae..6f2b5a06027ea 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala
@@ -17,7 +17,7 @@
package org.apache.spark.metrics.sink
-trait Sink {
+private[spark] trait Sink {
def start: Unit
def stop: Unit
}
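Since MetricsSystem now instantiates sinks reflectively through a (Properties, MetricRegistry, SecurityManager) constructor, any sink must expose that exact shape. A hypothetical StdoutSink (not part of the patch) illustrating the contract:

package org.apache.spark.metrics.sink

import java.util.Properties
import java.util.concurrent.TimeUnit

import com.codahale.metrics.{ConsoleReporter, MetricRegistry}

import org.apache.spark.SecurityManager

// Hypothetical sink: the three-argument constructor below is exactly what
// MetricsSystem looks up via getConstructor and invokes via newInstance.
private[spark] class StdoutSink(val property: Properties, val registry: MetricRegistry,
    securityMgr: SecurityManager) extends Sink {
  val reporter = ConsoleReporter.forRegistry(registry).build()
  override def start() { reporter.start(10, TimeUnit.SECONDS) }
  override def stop() { reporter.stop() }
}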
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala
index 75cb2b8973aa1..f865f9648a91e 100644
--- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala
@@ -20,7 +20,7 @@ package org.apache.spark.metrics.source
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet}
-class JvmSource extends Source {
+private[spark] class JvmSource extends Source {
val sourceName = "jvm"
val metricRegistry = new MetricRegistry()
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/Source.scala b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
index 3fee55cc6dcd5..1dda2cd83b2a9 100644
--- a/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
@@ -19,7 +19,7 @@ package org.apache.spark.metrics.source
import com.codahale.metrics.MetricRegistry
-trait Source {
+private[spark] trait Source {
def sourceName: String
def metricRegistry: MetricRegistry
}
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index d3c09b16063d6..04df2f3b0d696 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Max chunk size is " + maxChunkSize)
}
+ val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
- new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
+ val security = if (isSecurityNeg) 1 else 0
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
@@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index f2e3c1a14ecc6..3ffaaab23d0f5 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -17,6 +17,11 @@
package org.apache.spark.network
+import org.apache.spark._
+import org.apache.spark.SparkSaslServer
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
import java.net._
import java.nio._
import java.nio.channels._
@@ -27,20 +32,23 @@ import org.apache.spark._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
extends Logging {
- def this(channel_ : SocketChannel, selector_ : Selector) = {
+ var sparkSaslServer: SparkSaslServer = null
+ var sparkSaslClient: SparkSaslClient = null
+
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_)
}
channel.configureBlocking(false)
channel.socket.setTcpNoDelay(true)
channel.socket.setReuseAddress(true)
channel.socket.setKeepAlive(true)
- /*channel.socket.setReceiveBufferSize(32768) */
+ /* channel.socket.setReceiveBufferSize(32768) */
@volatile private var closed = false
var onCloseCallback: Connection => Unit = null
@@ -49,6 +57,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
+ /**
+ * Used to synchronize client requests: client's work-related requests must
+ * wait until SASL authentication completes.
+ */
+ private val authenticated = new Object()
+
+ def getAuthenticated(): Object = authenticated
+
+ def isSaslComplete(): Boolean
+
def resetForceReregister(): Boolean
// Read channels typically do not register for write and write does not for read
@@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
// Will be true for ReceivingConnection, false for SendingConnection.
def changeInterestForRead(): Boolean
+ private def disposeSasl() {
+ if (sparkSaslServer != null) {
+ sparkSaslServer.dispose()
+ }
+
+ if (sparkSaslClient != null) {
+ sparkSaslClient.dispose()
+ }
+ }
+
// On receiving a write event, should we change the interest for this channel or not ?
// Will be false for ReceivingConnection, true for SendingConnection.
// Actually, for now, should not get triggered for ReceivingConnection
@@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
k.cancel()
}
channel.close()
+ disposeSasl()
callOnCloseCallback()
}
@@ -168,17 +197,21 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[spark]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
- extends Connection(SocketChannel.open, selector_, remoteId_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
- private class Outbox(fair: Int = 0) {
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
+ }
+
+ private class Outbox {
val messages = new Queue[Message]()
- val defaultChunkSize = 65536 //32768 //16384
+ val defaultChunkSize = 65536
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
messages.synchronized{
- /*messages += message*/
+ /* messages += message */
messages.enqueue(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
@@ -186,42 +219,10 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
}
def getChunk(): Option[MessageChunk] = {
- fair match {
- case 0 => getChunkFIFO()
- case 1 => getChunkRR()
- case _ => throw new Exception("Unexpected fairness policy in outbox")
- }
- }
-
- private def getChunkFIFO(): Option[MessageChunk] = {
- /*logInfo("Using FIFO")*/
messages.synchronized {
while (!messages.isEmpty) {
- val message = messages(0)
- val chunk = message.getChunkForSending(defaultChunkSize)
- if (chunk.isDefined) {
- messages += message // this is probably incorrect, it wont work as fifo
- if (!message.started) {
- logDebug("Starting to send [" + message + "]")
- message.started = true
- message.startTime = System.currentTimeMillis
- }
- return chunk
- } else {
- message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
- "] in " + message.timeTaken )
- }
- }
- }
- None
- }
-
- private def getChunkRR(): Option[MessageChunk] = {
- messages.synchronized {
- while (!messages.isEmpty) {
- /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
- /*val message = messages(nextMessageToBeUsed)*/
+ /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+ /* val message = messages(nextMessageToBeUsed) */
val message = messages.dequeue
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
@@ -247,20 +248,21 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
}
}
- // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
+ // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
// lock it are invoked in context of other locks)
- private val outbox = new Outbox(1)
+ private val outbox = new Outbox()
/*
- This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly
- different purpose. This flag is to see if we need to force reregister for write even when we
+ This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly
+ different purpose. This flag is to see if we need to force reregister for write even when we
do not have any pending bytes to write to socket.
- This can happen due to a race between adding pending buffers, and checking for existing of
+ This can happen due to a race between adding pending buffers, and checking for existing of
data as detailed in https://github.com/mesos/spark/pull/791
*/
private var needForceReregister = false
+
val currentBuffers = new ArrayBuffer[ByteBuffer]()
- /*channel.socket.setSendBufferSize(256 * 1024)*/
+ /* channel.socket.setSendBufferSize(256 * 1024) */
override def getRemoteAddress() = address
@@ -348,11 +350,12 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// If we have 'seen' pending messages, then reset flag - since we handle that as
// normal registering of event (below)
if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+
currentBuffers ++= buffers
}
case None => {
// changeConnectionKeyInterest(0)
- /*key.interestOps(0)*/
+ /* key.interestOps(0) */
return false
}
}
@@ -416,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
- extends Connection(channel_, selector_) {
+private[spark] class ReceivingConnection(
+ channel_ : SocketChannel,
+ selector_ : Selector,
+ id_ : ConnectionId)
+ extends Connection(channel_, selector_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
+ }
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
@@ -428,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
+ newMessage.isSecurityNeg = header.securityNeg == 1
logDebug(
"Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
@@ -473,7 +484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
- var onReceiveCallback: (Connection , Message) => Unit = null
+ var onReceiveCallback: (Connection, Message) => Unit = null
var currentChunk: MessageChunk = null
channel.register(selector, SelectionKey.OP_READ)
@@ -529,10 +540,10 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
return false
}
- /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
+ /* logDebug("Read " + bytesRead + " bytes for the buffer") */
if (currentChunk.buffer.remaining == 0) {
- /*println("Filled buffer at " + System.currentTimeMillis)*/
+ /* println("Filled buffer at " + System.currentTimeMillis) */
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
@@ -548,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
new file mode 100644
index 0000000000000..d579c165a1917
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+ override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
+}
+
+private[spark] object ConnectionId {
+
+ def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+ val res = connectionIdString.split("_").map(_.trim())
+ if (res.size != 3) {
+ throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+ " to a ConnectionId Object")
+ }
+ new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+ }
+}
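A round-trip sketch of the string form, which the SASL handshake below uses to route responses, exercising only the API defined above:

import org.apache.spark.network.ConnectionId

// Sketch: ids serialize as "host_port_uniqId" and parse back losslessly.
val parsed = ConnectionId.createConnectionIdFromString("host1_9999_1")
assert(parsed.connectionManagerId.host == "host1")
assert(parsed.connectionManagerId.port == 9999)
assert(parsed.uniqId == 1)
assert(parsed.toString == "host1_9999_1")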
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 3dd82bee0b5fd..dcbbc1853186b 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -17,10 +17,12 @@
package org.apache.spark.network
-import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.atomic.AtomicInteger
+
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
@@ -28,13 +30,16 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
+
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
+import scala.language.postfixOps
import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SystemClock, Utils}
-private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging {
+private[spark] class ConnectionManager(port: Int, conf: SparkConf,
+ securityManager: SecurityManager) extends Logging {
class MessageStatus(
val message: Message,
@@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private val selector = SelectorProvider.provider.openSelector()
+ // default to a 30-second timeout waiting for authentication
+ private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
+
private val handleMessageExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.handler.threads.min", 20),
conf.getInt("spark.core.connection.handler.threads.max", 60),
@@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
new LinkedBlockingDeque[Runnable]())
private val serverChannel = ServerSocketChannel.open()
+ // used to track the SendingConnections waiting to do SASL negotiation
+ private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
+ with SynchronizedMap[ConnectionId, SendingConnection]
private val connectionsByKey =
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
@@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
@@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+ // used in combination with the ConnectionManagerId to create unique Connection ids
+ // to be able to track asynchronous messages
+ private val idCount: AtomicInteger = new AtomicInteger(1)
+
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
@@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
// accept them all in a tight loop. non blocking accept with no processing, should be fine
while (newChannel != null) {
try {
- val newConnection = new ReceivingConnection(newChannel, selector)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
+ connectionsAwaitingSasl -= connection.connectionId
messageStatuses.synchronized {
messageStatuses
@@ -481,18 +500,141 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val creationTime = System.currentTimeMillis
def run() {
logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message)
+ handleMessage(connectionManagerId, message, connection)
logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
}
}
handleMessageExecutor.execute(runnable)
- /*handleMessage(connection, message)*/
+ /* handleMessage(connection, message) */
+ }
+
+ private def handleClientAuthentication(
+ waitingConn: SendingConnection,
+ securityMsg: SecurityMessage,
+ connectionId : ConnectionId) {
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ } else {
+ var replyToken : Array[Byte] = null
+ try {
+ replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ }
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId.toString())
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
+ } catch {
+ case e: Exception => {
+ logError("Error handling sasl client authentication", e)
+ waitingConn.close()
+ throw new Exception("Error evaluating sasl response: " + e)
+ }
+ }
+ }
+ }
+
+ private def handleServerAuthentication(
+ connection: Connection,
+ securityMsg: SecurityMessage,
+ connectionId: ConnectionId) {
+ if (!connection.isSaslComplete()) {
+ logDebug("saslContext not established")
+ var replyToken : Array[Byte] = null
+ try {
+ connection.synchronized {
+ if (connection.sparkSaslServer == null) {
+ logDebug("Creating sasl Server")
+ connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ }
+ }
+ replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
+ if (connection.isSaslComplete()) {
+ logDebug("Server sasl completed: " + connection.connectionId)
+ } else {
+ logDebug("Server sasl not completed: " + connection.connectionId)
+ }
+ if (replyToken != null) {
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId)
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security Message")
+ sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
+ }
+ } catch {
+ case e: Exception => {
+ logError("Error in server auth negotiation: " + e)
+ // It would probably be better to send an error message telling the other side auth failed
+ // but for now just close
+ connection.close()
+ }
+ }
+ } else {
+ logDebug("connection already established for this connection id: " + connection.connectionId)
+ }
+ }
+
+
+ private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
+ if (bufferMessage.isSecurityNeg) {
+ logDebug("This is security neg message")
+
+ // parse as SecurityMessage
+ val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
+ val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)
+
+ connectionsAwaitingSasl.get(connectionId) match {
+ case Some(waitingConn) => {
+ // Client - this must be in response to us doing Send
+ logDebug("Client handleAuth for id: " + waitingConn.connectionId)
+ handleClientAuthentication(waitingConn, securityMsg, connectionId)
+ }
+ case None => {
+ // Server - someone sent us something and we haven't authenticated yet
+ logDebug("Server handleAuth for id: " + connectionId)
+ handleServerAuthentication(conn, securityMsg, connectionId)
+ }
+ }
+ return true
+ } else {
+ if (!conn.isSaslComplete()) {
+ // We could handle this better and tell the client we need to do authentication
+ // negotiation, but for now just ignore such messages.
+ logError("Received a non-security message on a connection that is not yet " +
+ "authenticated; ignoring it!")
+ return true
+ }
+ }
+ return false
}
- private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ private def handleMessage(
+ connectionManagerId: ConnectionManagerId,
+ message: Message,
+ connection: Connection) {
logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
message match {
case bufferMessage: BufferMessage => {
+ if (authEnabled) {
+ val res = handleAuthentication(connection, bufferMessage)
+ if (res == true) {
+ // message was security negotiation so skip the rest
+ logDebug("After handleAuth result was true, returning")
+ return
+ }
+ }
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
@@ -541,20 +683,124 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
}
}
+ private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) {
+ // see if we need to do sasl before writing
+ // this should only be the first negotiation as the Client!!!
+ if (!conn.isSaslComplete()) {
+ conn.synchronized {
+ if (conn.sparkSaslClient == null) {
+ conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ var firstResponse: Array[Byte] = null
+ try {
+ firstResponse = conn.sparkSaslClient.firstToken()
+ var securityMsg = SecurityMessage.fromResponse(firstResponse,
+ conn.connectionId.toString())
+ var message = securityMsg.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ connectionsAwaitingSasl += ((conn.connectionId, conn))
+ sendSecurityMessage(connManagerId, message)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ } catch {
+ case e: Exception => {
+ logError("Error getting first response from the SaslClient.", e)
+ conn.close()
+ throw new Exception("Error getting first response from the SaslClient")
+ }
+ }
+ }
+ }
+ } else {
+ logDebug("Sasl already established ")
+ }
+ }
+
+ // allow us to add messages to the inbox for doing sasl negotiating
+ private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
+ newConnectionId)
+ logInfo("creating new sending connection for security! " + newConnectionId )
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ?
+ // We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ message.senderAddress = id.toSocketAddress()
+ logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
+ val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
+
+ // send the security message; the outgoing connection has not been authenticated yet
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
+ newConnectionId)
+ logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
- // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it
- // useful in our test-env ... If we do re-add it, we should consistently use it everywhere I
- // guess ?
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+ if (authEnabled) {
+ checkSendAuthFirst(connectionManagerId, connection)
+ }
message.senderAddress = id.toSocketAddress()
+ logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
+ "connectionid: " + connection.connectionId)
+
+ if (authEnabled) {
+ // if we aren't authenticated yet, block the senders until authentication completes
+ try {
+ connection.getAuthenticated().synchronized {
+ val clock = SystemClock
+ val startTime = clock.getTime()
+
+ while (!connection.isSaslComplete()) {
+ logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
+ // use a timeout in case the remote side never responds
+ connection.getAuthenticated().wait(500)
+ if (((clock.getTime() - startTime) >= (authTimeout * 1000))
+ && (!connection.isSaslComplete())) {
+ // took too long to authenticate the connection; something probably went wrong
+ throw new Exception("Took too long for authentication to " + connectionManagerId +
+ ", waited " + authTimeout + " seconds, failing.")
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Exception while waiting for authentication.", e)
+
+ // need to tell sender it failed
+ messageStatuses.synchronized {
+ val s = messageStatuses.get(message.id)
+ s match {
+ case Some(msgStatus) => {
+ messageStatuses -= message.id
+ logInfo("Notifying " + msgStatus.connectionManagerId)
+ msgStatus.synchronized {
+ msgStatus.attempted = true
+ msgStatus.acked = false
+ msgStatus.markDone()
+ }
+ }
+ case None => {
+ logError("no messageStatus for failed message id: " + message.id)
+ }
+ }
+ }
+ }
+ }
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
@@ -606,20 +852,21 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private[spark] object ConnectionManager {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
- /*testSequentialSending(manager)*/
- /*System.gc()*/
+ /* testSequentialSending(manager) */
+ /* System.gc() */
- /*testParallelSending(manager)*/
- /*System.gc()*/
+ /* testParallelSending(manager) */
+ /* System.gc() */
- /*testParallelDecreasingSending(manager)*/
- /*System.gc()*/
+ /* testParallelDecreasingSending(manager) */
+ /* System.gc() */
testContinuousSending(manager)
System.gc()
@@ -701,7 +948,7 @@ private[spark] object ConnectionManager {
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
- /*println("Started at " + startTime + ", finished at " + finishTime) */
+ /* println("Started at " + startTime + ", finished at " + finishTime) */
println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
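Putting the pieces together, a minimal sketch of standing up an authenticated manager; the spark.authenticate key is assumed from the wider patch, while the timeout key and the three-argument constructor both appear in this file:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network.ConnectionManager

// Sketch: enable authentication (key assumed) and give slow networks more
// than the default 30 seconds to complete the SASL handshake before senders
// waiting on getAuthenticated() give up.
val conf = new SparkConf()
  .set("spark.authenticate", "true")
  .set("spark.core.connection.auth.wait.timeout", "60")
val manager = new ConnectionManager(0, conf, new SecurityManager(conf))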
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
index 35f64134b073a..4894ecd41f6eb 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
@@ -37,20 +37,20 @@ private[spark] object ConnectionManagerTest extends Logging{
"[size of msg in MB (integer)] [count] [await time in seconds)] ")
System.exit(1)
}
-
+
if (args(0).startsWith("local")) {
println("This runs only on a mesos cluster")
}
-
+
val sc = new SparkContext(args(0), "ConnectionManagerTest")
val slavesFile = Source.fromFile(args(1))
val slaves = slavesFile.mkString.split("\n")
slavesFile.close()
- /*println("Slaves")*/
- /*slaves.foreach(println)*/
+ /* println("Slaves") */
+ /* slaves.foreach(println) */
val tasknum = if (args.length > 2) args(2).toInt else slaves.length
- val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
+ val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
val count = if (args.length > 4) args(4).toInt else 3
val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " +
@@ -64,16 +64,16 @@ private[spark] object ConnectionManagerTest extends Logging{
(0 until count).foreach(i => {
val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
val connManager = SparkEnv.get.connectionManager
- val thisConnManagerId = connManager.id
- connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ val thisConnManagerId = connManager.id
+ connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
logInfo("Received [" + msg + "] from [" + id + "]")
None
})
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
-
- val startTime = System.currentTimeMillis
+
+ val startTime = System.currentTimeMillis
val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId =>
{
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
@@ -84,7 +84,7 @@ private[spark] object ConnectionManagerTest extends Logging{
val results = futures.map(f => Await.result(f, awaitTime))
val finishTime = System.currentTimeMillis
Thread.sleep(5000)
-
+
val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime
val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms *
@@ -92,11 +92,11 @@ private[spark] object ConnectionManagerTest extends Logging{
logInfo(resultStr)
resultStr
}).collect()
-
- println("---------------------")
- println("Run " + i)
+
+ println("---------------------")
+ println("Run " + i)
resultStrs.foreach(println)
- println("---------------------")
+ println("---------------------")
})
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index 20fe67661844f..7caccfdbb44f9 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -27,6 +27,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var started = false
var startTime = -1L
var finishTime = -1L
+ var isSecurityNeg = false
def size: Int
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
index 9bcbc6141a502..ead663ede7a1c 100644
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
+ val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
// No need to change this, at 'use' time, we do a reverse lookup of the hostname.
@@ -40,6 +41,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
+ putInt(securityNeg).
putInt(ip.size).
put(ip).
putInt(port).
@@ -48,12 +50,13 @@ private[spark] class MessageChunkHeader(
}
override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
+ " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg
+
}
private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
+ val HEADER_SIZE = 44
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
@@ -64,11 +67,13 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
+ val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
+ new InetSocketAddress(ip, port))
}
}
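The new field is why HEADER_SIZE grows from 40 to 44: one extra Int travels after `other`. A round-trip sketch, assuming the Message.BUFFER_MESSAGE type constant from Message.scala and that `header.buffer` yields exactly HEADER_SIZE bytes (the precondition create() checks):

import java.net.{InetAddress, InetSocketAddress}
import org.apache.spark.network.{Message, MessageChunkHeader}

// Sketch: encode a header and parse it back, exercising securityNeg.
val addr = new InetSocketAddress(InetAddress.getByName("127.0.0.1"), 9999)
val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 1, 1024, 1024, 0, 1, addr)
val parsed = MessageChunkHeader.create(header.buffer)
assert(parsed.securityNeg == 1)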
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
index 9976255c7e251..53a6038a9b59e 100644
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -18,17 +18,17 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object ReceiverTest {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
- val buffer = ByteBuffer.wrap("response".getBytes)
+ /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */
+ val buffer = ByteBuffer.wrap("response".getBytes("utf-8"))
Some(Message.createBufferMessage(buffer, msg.id))
})
Thread.currentThread.join()
diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
new file mode 100644
index 0000000000000..a1dfc4094cca7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.StringBuilder
+
+import org.apache.spark._
+import org.apache.spark.network._
+
+/**
+ * SecurityMessage is a class that contains the connectionId and SASL token
+ * used in SASL negotiation. SecurityMessage has routines for converting
+ * it to and from a BufferMessage so that it can be sent by the ConnectionManager
+ * and easily consumed by users when received.
+ * The API was modeled after BlockMessage.
+ *
+ * The connectionId is the connectionId of the client side. Since
+ * message passing is asynchronous, it's possible for the server (receiving)
+ * side to get multiple different types of messages on the same connection,
+ * so the connectionId is used to know which connection the security message
+ * is intended for.
+ *
+ * For instance, let's say we are node_0. We need to send data to node_1. The node_0 side
+ * is acting as a client and connecting to node_1. SASL negotiation has to occur
+ * between node_0 and node_1 before node_1 trusts node_0, so node_0 sends a security message.
+ * node_1 receives the message from node_0, but before it can process it and send a response,
+ * some thread on node_1 decides it needs to send data to node_0, so it connects to node_0
+ * and sends a security message of its own to authenticate as a client. Now node_0 gets
+ * the message and needs to decide whether it is a response to its own client-side handshake
+ * (from the first send) or just node_1 trying to connect to it to send data. This is where
+ * the connectionId field is used: node_0 can look up the connectionId to see whether the
+ * message is a response to it acting as a client or to someone sending it other data.
+ *
+ * The format of a SecurityMessage as it is sent is:
+ * - Length of the ConnectionId
+ * - ConnectionId
+ * - Length of the token
+ * - Token
+ */
+private[spark] class SecurityMessage() extends Logging {
+
+ private var connectionId: String = null
+ private var token: Array[Byte] = null
+
+ def set(byteArr: Array[Byte], newconnectionId: String) {
+ if (byteArr == null) {
+ token = new Array[Byte](0)
+ } else {
+ token = byteArr
+ }
+ connectionId = newconnectionId
+ }
+
+ /**
+ * Read the given buffer and set the members of this class.
+ */
+ def set(buffer: ByteBuffer) {
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ connectionId = idBuilder.toString()
+
+ val tokenLength = buffer.getInt()
+ token = new Array[Byte](tokenLength)
+ if (tokenLength > 0) {
+ buffer.get(token, 0, tokenLength)
+ }
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getConnectionId: String = {
+ return connectionId
+ }
+
+ def getToken: Array[Byte] = {
+ return token
+ }
+
+ /**
+ * Create a BufferMessage that can be sent by the ConnectionManager containing
+ * the security information from this class.
+ * @return BufferMessage
+ */
+ def toBufferMessage: BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ // 4 bytes for the length of the connectionId
+ // connectionId is of type char, so multiply the length by 2 to get the number of bytes
+ // 4 bytes for the length of the token
+ // token is a byte array, so just take its length
+ var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
+ buffer.putInt(connectionId.length())
+ connectionId.foreach((x: Char) => buffer.putChar(x))
+ buffer.putInt(token.length)
+
+ if (token.length > 0) {
+ buffer.put(token)
+ }
+ buffer.flip()
+ buffers += buffer
+
+ var message = Message.createBufferMessage(buffers)
+ logDebug("message total size is : " + message.size)
+ message.isSecurityNeg = true
+ return message
+ }
+
+ override def toString: String = {
+ "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
+ }
+}
+
+private[spark] object SecurityMessage {
+
+ /**
+ * Convert the given BufferMessage to a SecurityMessage by parsing the contents
+ * of the BufferMessage and populating the SecurityMessage fields.
+ * @param bufferMessage is a BufferMessage that was received
+ * @return new SecurityMessage
+ */
+ def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(bufferMessage)
+ newSecurityMessage
+ }
+
+ /**
+ * Create a SecurityMessage to send from a given saslResponse.
+ * @param response is the response to a challenge from the SaslClient or SaslServer
+ * @param connectionId the client connectionId we are negotiating authentication for
+ * @return a new SecurityMessage
+ */
+ def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(response, connectionId)
+ newSecurityMessage
+ }
+}
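A round-trip sketch using only the API defined above:

import org.apache.spark.network.SecurityMessage

// Sketch: a SASL token is framed as [idLen][id chars][tokenLen][token bytes]
// and carried in a BufferMessage flagged with isSecurityNeg = true.
val token = Array[Byte](1, 2, 3)
val wire = SecurityMessage.fromResponse(token, "host1_9999_1").toBufferMessage
val received = SecurityMessage.fromBufferMessage(wire)
assert(received.getConnectionId == "host1_9999_1")
assert(received.getToken.sameElements(token))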
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index 646f8425d9551..b8ea7c2cff9a2 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -18,8 +18,7 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object SenderTest {
def main(args: Array[String]) {
@@ -32,8 +31,8 @@ private[spark] object SenderTest {
val targetHost = args(0)
val targetPort = args(1).toInt
val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
-
- val manager = new ConnectionManager(0, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
@@ -51,11 +50,11 @@ private[spark] object SenderTest {
(0 until count).foreach(i => {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
val startTime = System.currentTimeMillis
- /*println("Started timer at " + startTime)*/
+ /* println("Started timer at " + startTime) */
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
.map { response =>
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
- new String(buffer.array)
+ new String(buffer.array, "utf-8")
}.getOrElse("none")
val finishTime = System.currentTimeMillis
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
index f9082ffb9141a..136c1912045aa 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -32,12 +32,12 @@ private[spark] class FileHeader (
buf.writeInt(fileLen)
buf.writeInt(blockId.name.length)
blockId.name.foreach((x: Char) => buf.writeByte(x))
- //padding the rest of header
+ // padding the rest of header
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
} else {
- throw new Exception("too long header " + buf.readableBytes)
- logInfo("too long header")
+ throw new Exception("too long header " + buf.readableBytes)
+ logInfo("too long header")
}
buf
}
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index 2625a7f6a575a..5cdbc306e56a0 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -30,9 +30,18 @@ package org.apache
* type (e.g. RDD[(Int, Int)] through implicit conversions when you
* `import org.apache.spark.SparkContext._`.
*
- * Java programmers should reference the [[spark.api.java]] package
+ * Java programmers should reference the [[org.apache.spark.api.java]] package
* for Spark programming APIs in Java.
+ *
+ * Classes and methods marked with
+ * Experimental are user-facing features which have not been officially adopted by the
+ * Spark project. These are subject to change or removal in minor releases.
+ *
+ * Classes and methods marked with
+ * Developer API are intended for advanced users who want to extend Spark through
+ * lower-level interfaces. These are subject to change or removal in minor releases.
*/
+
package object spark {
// For package docs only
}
diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
index 5f4450859cc9b..aed0353344427 100644
--- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
+++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -17,9 +17,13 @@
package org.apache.spark.partial
+import org.apache.spark.annotation.Experimental
+
/**
- * A Double with error bars on it.
+ * :: Experimental ::
+ * A Double value with error bars and associated confidence.
*/
+@Experimental
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
override def toString(): String = "[%.3f, %.3f]".format(low, high)
}
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
index 40b70baabcad9..8bb78123e3c9c 100644
--- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
@@ -22,36 +22,33 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.Map
import scala.collection.mutable.HashMap
+import scala.reflect.ClassTag
import cern.jet.stat.Probability
-import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+import org.apache.spark.util.collection.OpenHashMap
/**
* An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
*/
-private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
+private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] {
var outputsMerged = 0
- var sums = new OLMap[T] // Sum of counts for each key
+ var sums = new OpenHashMap[T, Long]() // Sum of counts for each key
- override def merge(outputId: Int, taskResult: OLMap[T]) {
+ override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) {
outputsMerged += 1
- val iter = taskResult.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
+ taskResult.foreach { case (key, value) =>
+ sums.changeValue(key, value, _ + value)
}
}
override def currentResult(): Map[T, BoundedDouble] = {
if (outputsMerged == totalOutputs) {
val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getLongValue()
- result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ sums.foreach { case (key, sum) =>
+ result(key) = new BoundedDouble(sum, 1.0, sum, sum)
}
result
} else if (outputsMerged == 0) {
@@ -60,16 +57,13 @@ private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Dou
val p = outputsMerged.toDouble / totalOutputs
val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getLongValue
+ sums.foreach { case (key, sum) =>
val mean = (sum + 1 - p) / p
val variance = (sum + 1) * (1 - p) / (p * p)
val stdev = math.sqrt(variance)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
- result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ result(key) = new BoundedDouble(mean, confidence, low, high)
}
result
}
diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
index 812368e04ac0d..cadd0c7ed19ba 100644
--- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
+++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
@@ -17,6 +17,9 @@
package org.apache.spark.partial
+import org.apache.spark.annotation.Experimental
+
+@Experimental
class PartialResult[R](initialVal: R, isFinal: Boolean) {
private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
private var failure: Option[Exception] = None
@@ -41,7 +44,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
}
}
- /**
+ /**
* Set a handler to be called when this PartialResult completes. Only one completion handler
* is supported per PartialResult.
*/
@@ -57,7 +60,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
return this
}
- /**
+ /**
* Set a handler to be called if this PartialResult's job fails. Only one failure handler
* is supported per PartialResult.
*/
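A brief sketch of the handler API described above (one completion handler and one failure handler per `PartialResult`), assuming a local context:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "partial-demo")
val partial = sc.parallelize(1 to 100000).map(_.toDouble).sumApprox(timeout = 1000)

// Called with the final value once the full job completes.
partial.onComplete(bound => println("final sum estimate: " + bound))
// Called if the underlying job fails.
partial.onFail(e => println("approximate job failed: " + e.getMessage))
```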
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index d1c74a5063510..aed951a40b40c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -24,11 +24,14 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.reflect.ClassTag
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.annotation.Experimental
/**
+ * :: Experimental ::
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
+@Experimental
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
/**
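For illustration, a minimal sketch of the asynchronous actions this conversion enables (the returned `FutureAction` is a standard Scala `Future` that can also be cancelled):

```scala
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "async-demo")
val rdd = sc.parallelize(1 to 1000)

// countAsync submits the job and returns immediately.
val future = rdd.countAsync()
future.onComplete {
  case Success(n) => println("counted " + n + " elements")
  case Failure(e) => println("job failed: " + e)
}
```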
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index e6c4a6d3794a0..c64da8804d166 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -19,24 +19,30 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark._
import org.apache.spark.storage.{BlockId, BlockManager}
private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
}
private[spark]
-class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
+class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
+ @volatile private var _isValid = true
- override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
- new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
- }).toArray
+ override def getPartitions: Array[Partition] = {
+ assertValid()
+ (0 until blockIds.size).map(i => {
+ new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
+ }).toArray
+ }
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ assertValid()
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDPartition].blockId
blockManager.get(blockId) match {
@@ -47,7 +53,36 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId
}
override def getPreferredLocations(split: Partition): Seq[String] = {
+ assertValid()
locations_(split.asInstanceOf[BlockRDDPartition].blockId)
}
+
+ /**
+ * Remove the data blocks that this BlockRDD is made from. NOTE: This is an
+ * irreversible operation, as the data in the blocks cannot be recovered
+ * once removed. Use it with caution.
+ */
+ private[spark] def removeBlocks() {
+ blockIds.foreach { blockId =>
+ sc.env.blockManager.master.removeBlock(blockId)
+ }
+ _isValid = false
+ }
+
+ /**
+ * Whether this BlockRDD is actually usable. This will be false if the data blocks have been
+ * removed using `this.removeBlocks`.
+ */
+ private[spark] def isValid: Boolean = {
+ _isValid
+ }
+
+ /** Check if this BlockRDD is valid. If not valid, exception is thrown. */
+ private[spark] def assertValid() {
+ if (!_isValid) {
+ throw new SparkException(
+ "Attempted to use %s after its blocks have been removed!".format(toString))
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 699a10c96c227..9ff76892aed32 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -20,10 +20,13 @@ package org.apache.spark.rdd
import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
+import scala.language.existentials
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
+import org.apache.spark.serializer.Serializer
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -50,12 +53,17 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]
}
/**
+ * :: DeveloperApi ::
* A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
*
+ * Note: This is an internal API. We recommend using RDD.cogroup(...) instead of
+ * instantiating this directly.
+ *
* @param rdds parent RDDs.
- * @param part partitioner used to partition the shuffle output.
+ * @param part partitioner used to partition the shuffle output
*/
+@DeveloperApi
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
@@ -66,10 +74,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): CoGroupedRDD[K] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
+ this.serializer = serializer
this
}
@@ -80,7 +88,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[Any, Any](rdd, part, serializerClass)
+ new ShuffleDependency[Any, Any](rdd, part, serializer)
}
}
}
@@ -102,7 +110,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array
}
- override val partitioner = Some(part)
+ override val partitioner: Some[Partitioner] = Some(part)
override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = {
val sparkConf = SparkEnv.get.conf
@@ -113,18 +121,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
// A list of (rdd iterator, dependency number) pairs
val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
// Read them from the parent
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
- }
- case ShuffleCoGroupSplitDep(shuffleId) => {
+
+ case ShuffleCoGroupSplitDep(shuffleId) =>
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+ val ser = Serializer.getSerializer(serializer)
val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
rddIterators += ((it, depNum))
- }
}
if (!externalSorting) {
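Per the note above, user code should reach this through `cogroup` rather than constructing the RDD directly; a minimal sketch:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "cogroup-demo")
val scores = sc.parallelize(Seq(("alice", 1), ("bob", 2)))
val ages = sc.parallelize(Seq(("alice", 30), ("carol", 25)))

// cogroup builds a CoGroupedRDD internally; every key from either side appears once.
scores.cogroup(ages).collect().foreach { case (k, (s, a)) =>
  println(k + " -> scores=" + s.mkString(",") + " ages=" + a.mkString(","))
}
```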
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 4e82b51313bf0..c45b759f007cc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -21,6 +21,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
@@ -32,7 +33,7 @@ import org.apache.spark._
* @param parentsIndices list of indices in the parent that have been coalesced into this partition
* @param preferredLocation the preferred location for this partition
*/
-case class CoalescedRDDPartition(
+private[spark] case class CoalescedRDDPartition(
index: Int,
@transient rdd: RDD[_],
parentsIndices: Array[Int],
@@ -70,7 +71,7 @@ case class CoalescedRDDPartition(
* @param maxPartitions number of desired partitions in the coalesced RDD
* @param balanceSlack used to trade off balance and locality: 1.0 is all locality, 0 is all balance
*/
-class CoalescedRDD[T: ClassTag](
+private[spark] class CoalescedRDD[T: ClassTag](
@transient var prev: RDD[T],
maxPartitions: Int,
balanceSlack: Double = 0.10)
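Since `CoalescedRDD` is now private, the public entry point is `RDD.coalesce`; a minimal sketch:

```scala
import org.apache.spark.SparkContext

val sc = new SparkContext("local[4]", "coalesce-demo")
val rdd = sc.parallelize(1 to 1000, numSlices = 100)

// Shrink to 10 partitions without a shuffle; a CoalescedRDD is constructed internally.
val fewer = rdd.coalesce(10)
println(fewer.partitions.length) // 10
```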
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a7b6b3b5146ce..9ca971c8a4c27 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -17,6 +17,7 @@
package org.apache.spark.rdd
+import org.apache.spark.annotation.Experimental
import org.apache.spark.{TaskContext, Logging}
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.MeanEvaluator
@@ -51,7 +52,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
/** Compute the standard deviation of this RDD's elements. */
def stdev(): Double = stats().stdev
- /**
+ /**
* Compute the sample standard deviation of this RDD's elements (which corrects for bias in
* estimating the standard deviation by dividing by N-1 instead of N).
*/
@@ -63,14 +64,22 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
*/
def sampleVariance(): Double = stats().sampleVariance
- /** (Experimental) Approximate operation to return the mean within a timeout. */
+ /**
+ * :: Experimental ::
+ * Approximate operation to return the mean within a timeout.
+ */
+ @Experimental
def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new MeanEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
- /** (Experimental) Approximate operation to return the sum within a timeout. */
+ /**
+ * :: Experimental ::
+ * Approximate operation to return the sum within a timeout.
+ */
+ @Experimental
def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new SumEvaluator(self.partitions.size, confidence)
@@ -114,13 +123,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
* e.g. for the array
* [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
* e.g. 1<=x<10, 10<=x<20, 20<=x<=50
- * And on the input of 1 and 50 we would have a histogram of 1, 0, 0
- *
+ * And on the input of 1 and 50 we would have a histogram of 1, 0, 1
+ *
* Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
* from an O(log n) insertion to O(1) per element (where n = # buckets) if you set evenBuckets
* to true.
* buckets must be sorted and not contain any duplicates.
- * buckets array must be at least two elements
+ * The buckets array must contain at least two elements.
* All NaN entries are treated the same. If you have a NaN bucket it must be
* the maximum value of the last position and all NaN entries will be counted
* in that bucket.
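A worked sketch of the bucket semantics described above, assuming a local context (note the last bucket is closed on the right):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "histogram-demo")
val values = sc.parallelize(Seq(1.0, 5.0, 10.0, 15.0, 20.0, 50.0))

// Buckets [1, 10), [10, 20), [20, 50]; 50 falls in the final, inclusive bucket.
val counts = values.histogram(Array(1.0, 10.0, 20.0, 50.0))
println(counts.mkString(", ")) // 2, 2, 2
```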
diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
index a84e5f9fd8ef8..a2d7e344cf1b2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
@@ -22,9 +22,9 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, SparkContext, TaskContext}
/**
- * An RDD that is empty, i.e. has no element in it.
+ * An RDD that has no partitions and no elements.
*/
-class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {
+private[spark] class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {
override def getPartitions: Array[Partition] = Array.empty
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index a374fc4a871b0..6547755764dcf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -17,17 +17,25 @@
package org.apache.spark.rdd
+import java.text.SimpleDateFormat
+import java.util.Date
import java.io.EOFException
+import scala.collection.immutable.Map
import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.InputSplit
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
+import org.apache.hadoop.mapred.JobID
+import org.apache.hadoop.mapred.TaskAttemptID
+import org.apache.hadoop.mapred.TaskID
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.NextIterator
@@ -43,12 +51,33 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
override def hashCode(): Int = 41 * (41 + rddId) + idx
override val index: Int = idx
+
+ /**
+ * Get any environment variables that should be added to the user's environment when running pipes.
+ * @return a Map with the environment variables and corresponding values; it may be empty
+ */
+ def getPipeEnvVars(): Map[String, String] = {
+ val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
+ val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
+ // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
+ // since it's not removed yet
+ Map("map_input_file" -> is.getPath().toString(),
+ "mapreduce_map_input_file" -> is.getPath().toString())
+ } else {
+ Map()
+ }
+ envVars
+ }
}
/**
+ * :: DeveloperApi ::
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
* sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`).
*
+ * Note: Instantiating this class directly is not recommended; please use
+ * [[org.apache.spark.SparkContext.hadoopRDD()]]
+ *
* @param sc The SparkContext to associate the RDD with.
* @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
* variable references an instance of JobConf, then that JobConf will be used for the Hadoop job.
@@ -58,8 +87,9 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
* @param inputFormatClass Storage format of the data to be read.
* @param keyClass Class of the key associated with the inputFormatClass.
* @param valueClass Class of the value associated with the inputFormatClass.
- * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
+ * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate.
*/
+@DeveloperApi
class HadoopRDD[K, V](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
@@ -67,7 +97,7 @@ class HadoopRDD[K, V](
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int)
+ minPartitions: Int)
extends RDD[(K, V)](sc, Nil) with Logging {
def this(
@@ -76,7 +106,7 @@ class HadoopRDD[K, V](
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
- minSplits: Int) = {
+ minPartitions: Int) = {
this(
sc,
sc.broadcast(new SerializableWritable(conf))
@@ -85,13 +115,16 @@ class HadoopRDD[K, V](
inputFormatClass,
keyClass,
valueClass,
- minSplits)
+ minPartitions)
}
protected val jobConfCacheKey = "rdd_%d_job_conf".format(id)
protected val inputFormatCacheKey = "rdd_%d_input_format".format(id)
+ // Used to build the JobTracker ID
+ private val createTime = new Date()
+
// Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
protected def getJobConf(): JobConf = {
val conf: Configuration = broadcastedConf.value.value
@@ -136,7 +169,7 @@ class HadoopRDD[K, V](
if (inputFormat.isInstanceOf[Configurable]) {
inputFormat.asInstanceOf[Configurable].setConf(jobConf)
}
- val inputSplits = inputFormat.getSplits(jobConf, minSplits)
+ val inputSplits = inputFormat.getSplits(jobConf, minPartitions)
val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
array(i) = new HadoopPartition(id, i, inputSplits(i))
@@ -144,14 +177,16 @@ class HadoopRDD[K, V](
array
}
- override def compute(theSplit: Partition, context: TaskContext) = {
+ override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new NextIterator[(K, V)] {
+
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
var reader: RecordReader[K, V] = null
-
val jobConf = getJobConf()
val inputFormat = getInputFormat(jobConf)
+ HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
+ context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
@@ -203,4 +238,17 @@ private[spark] object HadoopRDD {
def putCachedMetadata(key: String, value: Any) =
SparkEnv.get.hadoopJobMetadata.put(key, value)
+
+ /** Add Hadoop configuration specific to a single partition and attempt. */
+ def addLocalConfiguration(jobTrackerId: String, jobId: Int, splitId: Int, attemptId: Int,
+ conf: JobConf) {
+ val jobID = new JobID(jobTrackerId, jobId)
+ val taId = new TaskAttemptID(new TaskID(jobID, true, splitId), attemptId)
+
+ conf.set("mapred.tip.id", taId.getTaskID.toString)
+ conf.set("mapred.task.id", taId.toString)
+ conf.setBoolean("mapred.task.is.map", true)
+ conf.setInt("mapred.task.partition", splitId)
+ conf.set("mapred.job.id", jobID.toString)
+ }
}
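Following the note above, a minimal sketch of obtaining a HadoopRDD through SparkContext rather than by direct construction (the input path is hypothetical):

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileInputFormat, JobConf, TextInputFormat}
import org.apache.spark.SparkContext

val sc = new SparkContext("local", "hadoop-rdd-demo")
val conf = new JobConf()
FileInputFormat.setInputPaths(conf, "hdfs:///data/input") // hypothetical path

// SparkContext wires in the broadcast configuration and caching for us.
val lines = sc.hadoopRDD(conf, classOf[TextInputFormat],
  classOf[LongWritable], classOf[Text], minPartitions = 4)
println(lines.map(_._2.toString).first())
```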
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 8df8718f3b65b..a76a070b5b863 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.NextIterator
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index = idx
}
-
+// TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private
/**
* An RDD that executes an SQL query on a JDBC connection and reads results.
* For usage example, see test case JdbcRDDSuite.
@@ -116,7 +116,7 @@ class JdbcRDD[T: ClassTag](
}
object JdbcRDD {
- def resultSetToObjectArray(rs: ResultSet) = {
+ def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
}
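For context, a minimal sketch of constructing a JdbcRDD with the helper above (the connection URL and table are hypothetical); the SQL must contain two `?` placeholders for the partition bounds:

```scala
import java.sql.DriverManager
import org.apache.spark.SparkContext
import org.apache.spark.rdd.JdbcRDD

val sc = new SparkContext("local", "jdbc-demo")
val rows = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:derby:memory:demo"), // hypothetical URL
  "SELECT id, name FROM users WHERE id >= ? AND id <= ?",
  1, 1000, 4, // lower bound, upper bound, number of partitions
  JdbcRDD.resultSetToObjectArray)
println(rows.count())
```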
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index d1fff296878c3..ac1ccc06f238a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -24,10 +24,18 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
-
-private[spark]
-class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.input.WholeTextFileInputFormat
+import org.apache.spark.InterruptibleIterator
+import org.apache.spark.Logging
+import org.apache.spark.Partition
+import org.apache.spark.SerializableWritable
+import org.apache.spark.{SparkContext, TaskContext}
+
+private[spark] class NewHadoopPartition(
+ rddId: Int,
+ val index: Int,
+ @transient rawSplit: InputSplit with Writable)
extends Partition {
val serializableHadoopSplit = new SerializableWritable(rawSplit)
@@ -36,15 +44,20 @@ class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputS
}
/**
+ * :: DeveloperApi ::
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
* sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`).
*
+ * Note: Instantiating this class directly is not recommended; please use
+ * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]]
+ *
* @param sc The SparkContext to associate the RDD with.
* @param inputFormatClass Storage format of the data to be read.
* @param keyClass Class of the key associated with the inputFormatClass.
* @param valueClass Class of the value associated with the inputFormatClass.
* @param conf The Hadoop configuration.
*/
+@DeveloperApi
class NewHadoopRDD[K, V](
sc : SparkContext,
inputFormatClass: Class[_ <: InputFormat[K, V]],
@@ -59,17 +72,19 @@ class NewHadoopRDD[K, V](
private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
// private val serializableConf = new SerializableWritable(conf)
- private val jobtrackerId: String = {
+ private val jobTrackerId: String = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
formatter.format(new Date())
}
- @transient private val jobId = new JobID(jobtrackerId, id)
+ @transient protected val jobId = new JobID(jobTrackerId, id)
override def getPartitions: Array[Partition] = {
val inputFormat = inputFormatClass.newInstance
- if (inputFormat.isInstanceOf[Configurable]) {
- inputFormat.asInstanceOf[Configurable].setConf(conf)
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
}
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
@@ -80,16 +95,18 @@ class NewHadoopRDD[K, V](
result
}
- override def compute(theSplit: Partition, context: TaskContext) = {
+ override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobtrackerId, id, isMap = true, split.index, 0)
+ val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
- if (format.isInstanceOf[Configurable]) {
- format.asInstanceOf[Configurable].setConf(conf)
+ format match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
}
val reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
@@ -135,3 +152,30 @@ class NewHadoopRDD[K, V](
def getConf: Configuration = confBroadcast.value.value
}
+private[spark] class WholeTextFileRDD(
+ sc: SparkContext,
+ inputFormatClass: Class[_ <: WholeTextFileInputFormat],
+ keyClass: Class[String],
+ valueClass: Class[String],
+ @transient conf: Configuration,
+ minPartitions: Int)
+ extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) {
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val jobContext = newJobContext(conf, jobId)
+ inputFormat.setMaxSplitSize(jobContext, minPartitions)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+}
+
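Per the notes above, both RDDs are reached through SparkContext; a minimal sketch (paths are hypothetical):

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.spark.SparkContext

val sc = new SparkContext("local", "new-hadoop-demo")

// Preferred over new NewHadoopRDD(...).
val lines = sc.newAPIHadoopFile("hdfs:///data/input",
  classOf[TextInputFormat], classOf[LongWritable], classOf[Text])

// WholeTextFileRDD backs this convenience method: one (filename, content) pair per file.
val files = sc.wholeTextFiles("hdfs:///data/docs")
files.take(1).foreach { case (name, content) => println(name + ": " + content.length + " chars") }
```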
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index d5691f2267bfa..6a3f698444283 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -24,15 +24,31 @@ import org.apache.spark.{Logging, RangePartitioner}
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
* an implicit conversion. Import `org.apache.spark.SparkContext._` at the top of your program to
- * use these functions. They will work with any key type that has a `scala.math.Ordered`
- * implementation.
+ * use these functions. They will work with any key type `K` that has an implicit `Ordering[K]` in
+ * scope. Ordering objects already exist for all of the standard primitive types. Users can also
+ * define their own orderings for custom types, or to override the default ordering. The implicit
+ * ordering that is in the closest scope will be used.
+ *
+ * {{{
+ * import org.apache.spark.SparkContext._
+ *
+ * val rdd: RDD[(String, Int)] = ...
+ * implicit val caseInsensitiveOrdering = new Ordering[String] {
+ * override def compare(a: String, b: String) = a.toLowerCase.compare(b.toLowerCase)
+ * }
+ *
+ * // Sort by key, using the above case insensitive ordering.
+ * rdd.sortByKey()
+ * }}}
*/
-class OrderedRDDFunctions[K <% Ordered[K]: ClassTag,
+class OrderedRDDFunctions[K : Ordering : ClassTag,
V: ClassTag,
P <: Product2[K, V] : ClassTag](
self: RDD[P])
extends Logging with Serializable {
+ private val ordering = implicitly[Ordering[K]]
+
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
@@ -45,9 +61,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassTag,
shuffled.mapPartitions(iter => {
val buf = iter.toArray
if (ascending) {
- buf.sortWith((x, y) => x._1 < y._1).iterator
+ buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
} else {
- buf.sortWith((x, y) => x._1 > y._1).iterator
+ buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
}
}, preservesPartitioning = true)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index d29a1a9881cd4..5efb4388f6c71 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -30,34 +30,34 @@ import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
-import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob,
+  RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-// SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark.
-import org.apache.hadoop.mapred.SparkHadoopWriter
-import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil
-
import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.SparkHadoopWriter
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.SerializableHyperLogLog
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
+class PairRDDFunctions[K, V](self: RDD[(K, V)])
+ (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null)
extends Logging
with SparkHadoopMapReduceUtil
- with Serializable {
-
+ with Serializable
+{
/**
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C
@@ -76,9 +76,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true,
- serializerClass: String = null): RDD[(K, C)] = {
+ serializer: Serializer = null): RDD[(K, C)] = {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
- if (getKeyClass().isArray) {
+ if (keyClass.isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
}
@@ -96,13 +96,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
aggregator.combineValuesByKey(iter, context)
}, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
- .setSerializer(serializerClass)
+ .setSerializer(serializer)
partitioned.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
- val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
+ val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
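As a usage sketch of `combineByKey` (per-key averages with combined type `C = (sum, count)`), assuming the standard implicits:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "combine-demo")
val pairs = sc.parallelize(Seq(("a", 1.0), ("a", 3.0), ("b", 4.0)))

val avgs = pairs.combineByKey(
  (v: Double) => (v, 1),                                   // createCombiner
  (c: (Double, Int), v: Double) => (c._1 + v, c._2 + 1),   // mergeValue
  (c1: (Double, Int), c2: (Double, Int)) => (c1._1 + c2._1, c1._2 + c2._2) // mergeCombiners
).mapValues { case (sum, n) => sum / n }

println(avgs.collect().mkString(", ")) // (a,2.0), (b,4.0)
```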
@@ -171,7 +171,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
- if (getKeyClass().isArray) {
+ if (keyClass.isArray) {
throw new SparkException("reduceByKeyLocally() does not support array keys")
}
@@ -196,15 +196,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
/** Alias for reduceByKeyLocally */
+ @deprecated("Use reduceByKeyLocally", "1.0.0")
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
/**
- * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * :: Experimental ::
+ * Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
+ @Experimental
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[Map[K, BoundedDouble]] = {
self.map(_._1).countByValueApprox(timeout, confidence)
@@ -262,7 +265,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* Group the values for each key in the RDD into a single sequence. Allows controlling the
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
- def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
+ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
@@ -271,14 +274,14 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false)
- bufs.asInstanceOf[RDD[(K, Seq[V])]]
+ bufs.mapValues(_.toIterable)
}
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with into `numPartitions` partitions.
*/
- def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = {
+ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = {
groupByKey(new HashPartitioner(numPartitions))
}
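A small sketch of the new return type: grouped values now come back as `Iterable[V]` rather than `Seq[V]`:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

val sc = new SparkContext("local", "group-demo")
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

val grouped: RDD[(String, Iterable[Int])] = pairs.groupByKey()
grouped.collect().foreach { case (k, vs) => println(k + " -> " + vs.mkString(",")) }
```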
@@ -286,10 +289,14 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* Return a copy of the RDD partitioned using the specified partitioner.
*/
def partitionBy(partitioner: Partitioner): RDD[(K, V)] = {
- if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
+ if (keyClass.isArray && partitioner.isInstanceOf[HashPartitioner]) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
- if (self.partitioner == partitioner) self else new ShuffledRDD[K, V, (K, V)](self, partitioner)
+ if (self.partitioner == Some(partitioner)) {
+ self
+ } else {
+ new ShuffledRDD[K, V, (K, V)](self, partitioner)
+ }
}
/**
@@ -299,7 +306,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
- for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
+ for (v <- vs; w <- ws) yield (v, w)
}
}
@@ -312,9 +319,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
if (ws.isEmpty) {
- vs.iterator.map(v => (v, None))
+ vs.map(v => (v, None))
} else {
- for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w))
+ for (v <- vs; w <- ws) yield (v, Some(w))
}
}
}
@@ -329,9 +336,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
: RDD[(K, (Option[V], W))] = {
this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
if (vs.isEmpty) {
- ws.iterator.map(w => (None, w))
+ ws.map(w => (None, w))
} else {
- for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w)
+ for (v <- vs; w <- ws) yield (Some(v), w)
}
}
}
@@ -359,7 +366,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the existing partitioner/parallelism level.
*/
- def groupByKey(): RDD[(K, Seq[V])] = {
+ def groupByKey(): RDD[(K, Iterable[V])] = {
groupByKey(defaultPartitioner(self))
}
@@ -425,7 +432,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* Return the key-value pairs in this RDD to the master as a Map.
*/
def collectAsMap(): Map[K, V] = {
- val data = self.toArray()
+ val data = self.collect()
val map = new mutable.HashMap[K, V]
map.sizeHint(data.length)
data.foreach { case (k, v) => map.put(k, v) }
@@ -454,8 +461,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
- if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
+ : RDD[(K, (Iterable[V], Iterable[W]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
@@ -469,13 +477,15 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
- : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
- if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner)
cg.mapValues { case Seq(vs, w1s, w2s) =>
- (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]])
+ (vs.asInstanceOf[Seq[V]],
+ w1s.asInstanceOf[Seq[W1]],
+ w2s.asInstanceOf[Seq[W2]])
}
}
@@ -483,7 +493,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
+ def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
cogroup(other, defaultPartitioner(self, other))
}
@@ -492,7 +502,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
- : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = {
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
}
@@ -500,7 +510,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = {
+ def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Iterable[V], Iterable[W]))] = {
cogroup(other, new HashPartitioner(numPartitions))
}
@@ -509,18 +519,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int)
- : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = {
cogroup(other1, other2, new HashPartitioner(numPartitions))
}
/** Alias for cogroup. */
- def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
+ def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
cogroup(other, defaultPartitioner(self, other))
}
/** Alias for cogroup. */
def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
- : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = {
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
}
@@ -568,7 +578,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* supporting the key and value types K and V in this RDD.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
+ saveAsHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -579,7 +589,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
def saveAsHadoopFile[F <: OutputFormat[K, V]](
path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) {
val runtimeClass = fm.runtimeClass
- saveAsHadoopFile(path, getKeyClass, getValueClass, runtimeClass.asInstanceOf[Class[F]], codec)
+ saveAsHadoopFile(path, keyClass, valueClass, runtimeClass.asInstanceOf[Class[F]], codec)
}
/**
@@ -587,7 +597,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
- saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
+ saveAsNewAPIHadoopFile(path, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -604,46 +614,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
- val wrappedConf = new SerializableWritable(job.getConfiguration)
- NewFileOutputFormat.setOutputPath(job, new Path(path))
- val formatter = new SimpleDateFormat("yyyyMMddHHmm")
- val jobtrackerID = formatter.format(new Date())
- val stageId = self.id
- def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = {
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- /* "reduce task" */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- attemptNumber)
- val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
- val format = outputFormatClass.newInstance
- format match {
- case c: Configurable => c.setConf(wrappedConf.value)
- case _ => ()
- }
- val committer = format.getOutputCommitter(hadoopContext)
- committer.setupTask(hadoopContext)
- val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
- while (iter.hasNext) {
- val (k, v) = iter.next()
- writer.write(k, v)
- }
- writer.close(hadoopContext)
- committer.commitTask(hadoopContext)
- return 1
- }
- val jobFormat = outputFormatClass.newInstance
- /* apparently we need a TaskAttemptID to construct an OutputCommitter;
- * however we're only going to use this local OutputCommitter for
- * setupJob/commitJob, so we just use a dummy "map" task.
- */
- val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0)
- val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
- val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
- jobCommitter.setupJob(jobTaskContext)
- val count = self.context.runJob(self, writeShard _).sum
- jobCommitter.commitJob(jobTaskContext)
+ job.setOutputFormatClass(outputFormatClass)
+ job.getConfiguration.set("mapred.output.dir", path)
+ saveAsNewAPIHadoopDataset(job.getConfiguration)
}
/**
@@ -689,6 +662,63 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
saveAsHadoopDataset(conf)
}
+ /**
+ * Output the RDD to any Hadoop-supported storage system with new Hadoop API, using a Hadoop
+ * Configuration object for that storage system. The Conf should set an OutputFormat and any
+ * output paths required (e.g. a table name to write to) in the same way as it would be
+ * configured for a Hadoop MapReduce job.
+ */
+ def saveAsNewAPIHadoopDataset(conf: Configuration) {
+ val job = new NewAPIHadoopJob(conf)
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ val jobtrackerID = formatter.format(new Date())
+ val stageId = self.id
+ val wrappedConf = new SerializableWritable(job.getConfiguration)
+ val outfmt = job.getOutputFormatClass
+ val jobFormat = outfmt.newInstance
+
+ if (jobFormat.isInstanceOf[NewFileOutputFormat[_, _]]) {
+ // FileOutputFormat ignores the filesystem parameter
+ jobFormat.checkOutputSpecs(job)
+ }
+
+ def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = {
+ // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+ // around by taking a mod. We expect that no task will be attempted 2 billion times.
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ /* "reduce task" */
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
+ attemptNumber)
+ val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val format = outfmt.newInstance
+ format match {
+ case c: Configurable => c.setConf(wrappedConf.value)
+ case _ => ()
+ }
+ val committer = format.getOutputCommitter(hadoopContext)
+ committer.setupTask(hadoopContext)
+ val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+ try {
+ while (iter.hasNext) {
+ val (k, v) = iter.next()
+ writer.write(k, v)
+ }
+ }
+ finally {
+ writer.close(hadoopContext)
+ }
+ committer.commitTask(hadoopContext)
+ return 1
+ }
+
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0)
+ val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
+ val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
+ jobCommitter.setupJob(jobTaskContext)
+ self.context.runJob(self, writeShard _)
+ jobCommitter.commitJob(jobTaskContext)
+ }
+
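A usage sketch for the new method, configuring the output exactly as a MapReduce job would (the output path is hypothetical):

```scala
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "save-demo")
val pairs = sc.parallelize(Seq(("k1", 1), ("k2", 2)))
  .map { case (k, v) => (new Text(k), new IntWritable(v)) }

val job = new Job(sc.hadoopConfiguration)
job.setOutputKeyClass(classOf[Text])
job.setOutputValueClass(classOf[IntWritable])
job.setOutputFormatClass(classOf[TextOutputFormat[Text, IntWritable]])
job.getConfiguration.set("mapred.output.dir", "/tmp/save-demo-out") // hypothetical path
pairs.saveAsNewAPIHadoopDataset(job.getConfiguration)
```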
/**
* Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
* that storage system. The JobConf should set an OutputFormat and any output paths required
@@ -696,10 +726,10 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* MapReduce job.
*/
def saveAsHadoopDataset(conf: JobConf) {
- val outputFormatClass = conf.getOutputFormat
+ val outputFormatInstance = conf.getOutputFormat
val keyClass = conf.getOutputKeyClass
val valueClass = conf.getOutputValueClass
- if (outputFormatClass == null) {
+ if (outputFormatInstance == null) {
throw new SparkException("Output format class not set")
}
if (keyClass == null) {
@@ -708,10 +738,17 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
if (valueClass == null) {
throw new SparkException("Output value class not set")
}
+ SparkHadoopUtil.get.addCredentials(conf)
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
+ if (outputFormatInstance.isInstanceOf[FileOutputFormat[_, _]]) {
+ // FileOutputFormat ignores the filesystem parameter
+ val ignoredFs = FileSystem.get(conf)
+ conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf)
+ }
+
val writer = new SparkHadoopWriter(conf)
writer.preSetup()
@@ -722,15 +759,17 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
-
- var count = 0
- while(iter.hasNext) {
- val record = iter.next()
- count += 1
- writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
+ try {
+ var count = 0
+ while(iter.hasNext) {
+ val record = iter.next()
+ count += 1
+ writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
+ }
+ }
+ finally {
+ writer.close()
}
-
- writer.close()
writer.commit()
}
@@ -748,7 +787,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
*/
def values: RDD[V] = self.map(_._2)
- private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass
+ private[spark] def keyClass: Class[_] = kt.runtimeClass
+
+ private[spark] def valueClass: Class[_] = vt.runtimeClass
- private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass
+ private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index b0440ca7f32cf..f781a8d776f2a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -20,8 +20,10 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark.{NarrowDependency, Partition, TaskContext}
+import org.apache.spark.annotation.DeveloperApi
-class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition {
+private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition)
+ extends Partition {
override val index = idx
}
@@ -30,7 +32,7 @@ class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends
* Represents a dependency between the PartitionPruningRDD and its parent. In this
* case, the child RDD contains a subset of partitions of the parents'.
*/
-class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
+private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
extends NarrowDependency[T](rdd) {
@transient
@@ -45,11 +47,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
/**
+ * :: DeveloperApi ::
* A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*/
+@DeveloperApi
class PartitionPruningRDD[T: ClassTag](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
@@ -63,6 +67,7 @@ class PartitionPruningRDD[T: ClassTag](
}
+@DeveloperApi
object PartitionPruningRDD {
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index a84357b38414e..0c2cd7a24783b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -33,7 +33,7 @@ class PartitionerAwareUnionRDDPartition(
val idx: Int
) extends Partition {
var parents = rdds.map(_.partitions(idx)).toArray
-
+
override val index = idx
override def hashCode(): Int = idx
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index ce4c0d382baab..b5b8a5706deb3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
+import org.apache.spark.util.Utils
private[spark]
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
@@ -38,14 +39,14 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
*
* @param prev RDD to be sampled
* @param sampler a random sampler
- * @param seed random seed, default to System.nanoTime
+ * @param seed random seed
* @tparam T input RDD item type
* @tparam U sampled RDD item type
*/
-class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
+private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
prev: RDD[T],
sampler: RandomSampler[T, U],
- @transient seed: Long = System.nanoTime)
+ @transient seed: Long = Utils.random.nextLong)
extends RDD[U](prev) {
override def getPartitions: Array[Partition] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index abd4414e81f5c..5d77d37378458 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -17,6 +17,9 @@
package org.apache.spark.rdd
+import java.io.File
+import java.io.FilenameFilter
+import java.io.IOException
import java.io.PrintWriter
import java.util.StringTokenizer
@@ -27,17 +30,20 @@ import scala.io.Source
import scala.reflect.ClassTag
import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.util.Utils
+
/**
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
*/
-class PipedRDD[T: ClassTag](
+private[spark] class PipedRDD[T: ClassTag](
prev: RDD[T],
command: Seq[String],
envVars: Map[String, String],
printPipeContext: (String => Unit) => Unit,
- printRDDElement: (T, String => Unit) => Unit)
+ printRDDElement: (T, String => Unit) => Unit,
+ separateWorkingDir: Boolean)
extends RDD[String](prev) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
@@ -47,18 +53,69 @@ class PipedRDD[T: ClassTag](
command: String,
envVars: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null) =
- this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
+ printRDDElement: (T, String => Unit) => Unit = null,
+ separateWorkingDir: Boolean = false) =
+ this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement,
+ separateWorkingDir)
override def getPartitions: Array[Partition] = firstParent[T].partitions
+ /**
+ * A FilenameFilter that accepts anything that isn't equal to the name passed in.
+ * @param filterName name of the file or directory to leave out
+ */
+ class NotEqualsFileNameFilter(filterName: String) extends FilenameFilter {
+ def accept(dir: File, name: String): Boolean = {
+ !name.equals(filterName)
+ }
+ }
+
override def compute(split: Partition, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
+ // for compatibility with Hadoop which sets these env variables
+ // so the user code can access the input filename
+ if (split.isInstanceOf[HadoopPartition]) {
+ val hadoopSplit = split.asInstanceOf[HadoopPartition]
+ currentEnvVars.putAll(hadoopSplit.getPipeEnvVars())
+ }
+
+ // When the spark.worker.separated.working.directory option is turned on, each
+ // task will be run in a separate directory. This should resolve file
+ // access conflict issues.
+ val taskDirectory = "tasks" + File.separator + java.util.UUID.randomUUID.toString
+ var workInTaskDirectory = false
+ logDebug("taskDirectory = " + taskDirectory)
+ if (separateWorkingDir) {
+ val currentDir = new File(".")
+ logDebug("currentDir = " + currentDir.getAbsolutePath())
+ val taskDirFile = new File(taskDirectory)
+ taskDirFile.mkdirs()
+
+ try {
+ val tasksDirFilter = new NotEqualsFileNameFilter("tasks")
+
+ // Need to add symlinks to jars, files, and directories. On Yarn we could have
+ // directories and other files not known to the SparkContext that were added via the
+ // Hadoop distributed cache. We also don't want to symlink to the /tasks directories we
+ // are creating here.
+ for (file <- currentDir.list(tasksDirFilter)) {
+ val fileWithDir = new File(currentDir, file)
+ Utils.symlink(new File(fileWithDir.getAbsolutePath()),
+ new File(taskDirectory + File.separator + fileWithDir.getName()))
+ }
+ pb.directory(taskDirFile)
+ workInTaskDirectory = true
+ } catch {
+ case e: Exception => logError("Unable to set up task working directory: " + e.getMessage +
+ " (" + taskDirectory + ")", e)
+ }
+ }
+
val proc = pb.start()
val env = SparkEnv.get
@@ -104,6 +161,15 @@ class PipedRDD[T: ClassTag](
if (exitStatus != 0) {
throw new Exception("Subprocess exited with status " + exitStatus)
}
+
+ // Clean up the task working directory if it was used.
+ if (workInTaskDirectory) {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ Utils.deleteRecursively(new File(taskDirectory))
+ }
+ logDebug("Removed task working directory " + taskDirectory)
+ }
+
false
}
}
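A hedged sketch of how the new separateWorkingDir flag surfaces through RDD.pipe; everything except the flag itself (the command, the env map, and the SparkContext sc) is illustrative:

    val lines = sc.parallelize(Seq("a", "b", "c"))
    // With separateWorkingDir = true, each task runs its command inside its own
    // tasks/<uuid> directory populated with symlinks to the job's files, so
    // concurrent tasks on one worker cannot clobber each other's scratch files.
    val piped = lines.pipe(
      command = Seq("cat"),
      env = Map("EXAMPLE_VAR" -> "1"),
      separateWorkingDir = true)
    piped.collect()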
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 50320f40350cd..3b3524f33e811 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -20,12 +20,11 @@ package org.apache.spark.rdd
import java.util.Random
import scala.collection.Map
-import scala.collection.JavaConversions.mapAsScalaMap
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.{classTag, ClassTag}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
-import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
@@ -35,6 +34,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.spark._
import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
@@ -42,6 +42,7 @@ import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, SerializableHyperLogLog, Utils}
+import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
/**
@@ -86,7 +87,11 @@ abstract class RDD[T: ClassTag](
// Methods that should be implemented by subclasses of RDD
// =======================================================================
- /** Implemented by subclasses to compute a given partition. */
+ /**
+ * :: DeveloperApi ::
+ * Implemented by subclasses to compute a given partition.
+ */
+ @DeveloperApi
def compute(split: Partition, context: TaskContext): Iterator[T]
/**
@@ -101,7 +106,9 @@ abstract class RDD[T: ClassTag](
*/
protected def getDependencies: Seq[Dependency[_]] = deps
- /** Optionally overridden by subclasses to specify placement preferences. */
+ /**
+ * Optionally overridden by subclasses to specify placement preferences.
+ */
protected def getPreferredLocations(split: Partition): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
@@ -121,19 +128,11 @@ abstract class RDD[T: ClassTag](
@transient var name: String = null
/** Assign a name to this RDD */
- def setName(_name: String) = {
+ def setName(_name: String): RDD[T] = {
name = _name
this
}
- /** User-defined generator of this RDD*/
- @transient var generator = Utils.getCallSiteInfo.firstUserClass
-
- /** Reset generator*/
- def setGenerator(_generator: String) = {
- generator = _generator
- }
-
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. This can only be used to assign a new storage level if the RDD does not
@@ -145,9 +144,10 @@ abstract class RDD[T: ClassTag](
throw new UnsupportedOperationException(
"Cannot change storage level of an RDD after it was already assigned a level")
}
+ sc.persistRDD(this)
+ // Register the RDD with the ContextCleaner for automatic GC-based cleanup
+ sc.cleaner.foreach(_.registerRDDForCleanup(this))
storageLevel = newLevel
- // Register the RDD with the SparkContext
- sc.persistentRdds(id) = this
this
}
@@ -165,8 +165,7 @@ abstract class RDD[T: ClassTag](
*/
def unpersist(blocking: Boolean = true): RDD[T] = {
logInfo("Removing RDD " + id + " from persistence list")
- sc.env.blockManager.master.removeRdd(id, blocking)
- sc.persistentRdds.remove(id)
+ sc.unpersistRDD(id, blocking)
storageLevel = StorageLevel.NONE
this
}
@@ -231,6 +230,30 @@ abstract class RDD[T: ClassTag](
}
}
+ /**
+ * Return the ancestors of the given RDD that are related to it only through a sequence of
+ * narrow dependencies. This traverses the given RDD's dependency tree using DFS, but maintains
+ * no ordering on the RDDs returned.
+ */
+ private[spark] def getNarrowAncestors: Seq[RDD[_]] = {
+ val ancestors = new mutable.HashSet[RDD[_]]
+
+ def visit(rdd: RDD[_]) {
+ val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
+ val narrowParents = narrowDependencies.map(_.rdd)
+ val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
+ narrowParentsNotVisited.foreach { parent =>
+ ancestors.add(parent)
+ visit(parent)
+ }
+ }
+
+ visit(this)
+
+ // In case there is a cycle, do not include the root itself
+ ancestors.filterNot(_ == this).toSeq
+ }
+
/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
@@ -261,7 +284,7 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numPartitions: Int): RDD[T] =
+ def distinct(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1)
/**
@@ -278,7 +301,7 @@ abstract class RDD[T: ClassTag](
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
- def repartition(numPartitions: Int): RDD[T] = {
+ def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = {
coalesce(numPartitions, shuffle = true)
}
@@ -302,7 +325,8 @@ abstract class RDD[T: ClassTag](
* coalesce(1000, shuffle = true) will result in 1000 partitions with the
* data distributed using a hash partitioner.
*/
- def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = {
+ def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null)
+ : RDD[T] = {
if (shuffle) {
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
@@ -317,7 +341,10 @@ abstract class RDD[T: ClassTag](
/**
* Return a sampled subset of this RDD.
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+ def sample(withReplacement: Boolean,
+ fraction: Double,
+ seed: Long = Utils.random.nextLong): RDD[T] = {
+ require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
} else {
@@ -329,11 +356,11 @@ abstract class RDD[T: ClassTag](
* Randomly splits this RDD with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1
- * @param seed random seed, default to System.nanoTime
+ * @param seed random seed
*
* @return split RDDs in an array
*/
- def randomSplit(weights: Array[Double], seed: Long = System.nanoTime): Array[RDD[T]] = {
+ def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[RDD[T]] = {
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
@@ -341,7 +368,8 @@ abstract class RDD[T: ClassTag](
}.toArray
}
- def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
+ def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
+ {
var fraction = 0.0
var total = 0
val multiplier = 3.0
@@ -352,6 +380,10 @@ abstract class RDD[T: ClassTag](
throw new IllegalArgumentException("Negative number of elements requested")
}
+ if (initialCount == 0) {
+ return new Array[T](0)
+ }
+
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
@@ -370,7 +402,7 @@ abstract class RDD[T: ClassTag](
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
// If the first sample didn't turn out large enough, keep trying to take samples;
- // this shouldn't happen often because we use a big multiplier for thei initial size
+ // this shouldn't happen often because we use a big multiplier for the initial size
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
}
@@ -396,10 +428,11 @@ abstract class RDD[T: ClassTag](
*
* Note that this method performs a shuffle internally.
*/
- def intersection(other: RDD[T]): RDD[T] =
+ def intersection(other: RDD[T]): RDD[T] = {
this.map(v => (v, null)).cogroup(other.map(v => (v, null)))
.filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
.keys
+ }
/**
* Return the intersection of this RDD and another one. The output will not contain any duplicate
@@ -409,10 +442,12 @@ abstract class RDD[T: ClassTag](
*
* @param partitioner Partitioner to use for the resulting RDD
*/
- def intersection(other: RDD[T], partitioner: Partitioner): RDD[T] =
+ def intersection(other: RDD[T], partitioner: Partitioner)(implicit ord: Ordering[T] = null)
+ : RDD[T] = {
this.map(v => (v, null)).cogroup(other.map(v => (v, null)), partitioner)
.filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
.keys
+ }
/**
* Return the intersection of this RDD and another one. The output will not contain any duplicate
@@ -422,10 +457,11 @@ abstract class RDD[T: ClassTag](
*
* @param numPartitions How many partitions to use in the resulting RDD
*/
- def intersection(other: RDD[T], numPartitions: Int): RDD[T] =
+ def intersection(other: RDD[T], numPartitions: Int): RDD[T] = {
this.map(v => (v, null)).cogroup(other.map(v => (v, null)), new HashPartitioner(numPartitions))
.filter { case (_, (leftGroup, rightGroup)) => leftGroup.nonEmpty && rightGroup.nonEmpty }
.keys
+ }
/**
* Return an RDD created by coalescing all elements within each partition into an array.
@@ -439,22 +475,25 @@ abstract class RDD[T: ClassTag](
def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
/**
- * Return an RDD of grouped items.
+ * Return an RDD of grouped items. Each group consists of a key and a sequence of elements
+ * mapping to that key.
*/
- def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] =
+ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] =
groupBy[K](f, defaultPartitioner(this))
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
+ def groupBy[K](f: T => K, numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] =
groupBy(f, new HashPartitioner(numPartitions))
/**
- * Return an RDD of grouped items.
+ * Return an RDD of grouped items. Each group consists of a key and a sequence of elements
+ * mapping to that key.
*/
- def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
+ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null)
+ : RDD[(K, Iterable[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(p)
}
@@ -486,16 +525,19 @@ abstract class RDD[T: ClassTag](
* instead of constructing a huge String to concat all the elements:
* def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
* for (e <- record._2){f(e)}
+ * @param separateWorkingDir Use separate working directories for each task.
* @return the result RDD
*/
def pipe(
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
+ printRDDElement: (T, String => Unit) => Unit = null,
+ separateWorkingDir: Boolean = false): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
- if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+ if (printRDDElement ne null) sc.clean(printRDDElement) else null,
+ separateWorkingDir)
}
/**
@@ -518,9 +560,11 @@ abstract class RDD[T: ClassTag](
}
/**
+ * :: DeveloperApi ::
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
*/
+ @DeveloperApi
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
@@ -543,7 +587,8 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassTag, U: ClassTag]
+ @deprecated("use mapPartitionsWithIndex", "1.0.0")
+ def mapWith[A, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
mapPartitionsWithIndex((index, iter) => {
@@ -557,7 +602,8 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassTag, U: ClassTag]
+ @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0")
+ def flatMapWith[A, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
mapPartitionsWithIndex((index, iter) => {
@@ -571,7 +617,8 @@ abstract class RDD[T: ClassTag](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
+ @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
+ def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit) {
mapPartitionsWithIndex { (index, iter) =>
val a = constructA(index)
iter.map(t => {f(t, a); t})
@@ -583,7 +630,8 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ @deprecated("use mapPartitionsWithIndex and filter", "1.0.0")
+ def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
mapPartitionsWithIndex((index, iter) => {
val a = constructA(index)
iter.filter(t => p(t, a))
@@ -659,9 +707,22 @@ abstract class RDD[T: ClassTag](
Array.concat(results: _*)
}
+ /**
+ * Return an iterator that contains all of the elements in this RDD.
+ *
+ * The iterator will consume as much memory as the largest partition in this RDD.
+ */
+ def toLocalIterator: Iterator[T] = {
+ def collectPartition(p: Int): Array[T] = {
+ sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+ }
+ (0 until partitions.length).iterator.flatMap(i => collectPartition(i))
+ }
+
/**
* Return an array that contains all of the elements in this RDD.
*/
+ @deprecated("use collect", "1.0.0")
def toArray(): Array[T] = collect()
/**
@@ -689,7 +750,7 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
- def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+ def subtract(other: RDD[T], p: Partitioner)(implicit ord: Ordering[T] = null): RDD[T] = {
if (partitioner == Some(p)) {
// Our partitioner knows how to handle T (which, since we have a partitioner, is
// really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
@@ -775,9 +836,11 @@ abstract class RDD[T: ClassTag](
def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
/**
- * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * :: Experimental ::
+ * Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
+ @Experimental
def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
var result = 0L
@@ -795,46 +858,47 @@ abstract class RDD[T: ClassTag](
* Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
- def countByValue(): Map[T, Long] = {
+ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValue() does not support arrays")
}
// TODO: This should perhaps be distributed by default.
- def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
- val map = new OLMap[T]
- while (iter.hasNext) {
- val v = iter.next()
- map.put(v, map.getLong(v) + 1L)
+ def countPartition(iter: Iterator[T]): Iterator[OpenHashMap[T, Long]] = {
+ val map = new OpenHashMap[T, Long]
+ iter.foreach {
+ t => map.changeValue(t, 1L, _ + 1L)
}
Iterator(map)
}
- def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = {
- val iter = m2.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
+ def mergeMaps(m1: OpenHashMap[T, Long], m2: OpenHashMap[T, Long]): OpenHashMap[T, Long] = {
+ m2.foreach { case (key, value) =>
+ m1.changeValue(key, value, _ + value)
}
m1
}
val myResult = mapPartitions(countPartition).reduce(mergeMaps)
- myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
+ // Convert to a Scala mutable map
+ val mutableResult = scala.collection.mutable.Map[T, Long]()
+ myResult.foreach { case (k, v) => mutableResult.put(k, v) }
+ mutableResult
}
/**
- * (Experimental) Approximate version of countByValue().
+ * :: Experimental ::
+ * Approximate version of countByValue().
*/
- def countByValueApprox(
- timeout: Long,
- confidence: Double = 0.95
- ): PartialResult[Map[T, BoundedDouble]] = {
+ @Experimental
+ def countByValueApprox(timeout: Long, confidence: Double = 0.95)
+ (implicit ord: Ordering[T] = null)
+ : PartialResult[Map[T, BoundedDouble]] =
+ {
if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValueApprox() does not support arrays")
}
- val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
- val map = new OLMap[T]
- while (iter.hasNext) {
- val v = iter.next()
- map.put(v, map.getLong(v) + 1L)
+ val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T, Long] = { (ctx, iter) =>
+ val map = new OpenHashMap[T, Long]
+ iter.foreach {
+ t => map.changeValue(t, 1L, _ + 1L)
}
map
}
@@ -843,6 +907,7 @@ abstract class RDD[T: ClassTag](
}
/**
+ * :: Experimental ::
* Return approximate number of distinct elements in the RDD.
*
* The accuracy of approximation can be controlled through the relative standard deviation
@@ -850,6 +915,7 @@ abstract class RDD[T: ClassTag](
* more accurate counts but increase the memory footprint and vice versa. The default value of
* relativeSD is 0.05.
*/
+ @Experimental
def countApproxDistinct(relativeSD: Double = 0.05): Long = {
val zeroCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD))
aggregate(zeroCounter)(_.add(_), _.merge(_)).value.cardinality()
@@ -927,32 +993,61 @@ abstract class RDD[T: ClassTag](
}
/**
- * Returns the top K elements from this RDD as defined by
- * the specified implicit Ordering[T].
+ * Returns the top K (largest) elements from this RDD as defined by the specified
+ * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example:
+ * {{{
+ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1)
+ * // returns Array(12)
+ *
+ * sc.parallelize(Seq(2, 3, 4, 5, 6)).top(2)
+ * // returns Array(6, 5)
+ * }}}
+ *
+ * @param num the number of top elements to return
+ * @param ord the implicit ordering for T
+ * @return an array of top elements
+ */
+ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse)
+
+ /**
+ * Returns the first K (smallest) elements from this RDD as defined by the specified
+ * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]].
+ * For example:
+ * {{{
+ * sc.parallelize(Seq(10, 4, 2, 12, 3)).takeOrdered(1)
+ * // returns Array(2)
+ *
+ * sc.parallelize(Seq(2, 3, 4, 5, 6)).takeOrdered(2)
+ * // returns Array(2, 3)
+ * }}}
+ *
* @param num the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
- def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
+ def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = {
mapPartitions { items =>
- val queue = new BoundedPriorityQueue[T](num)
- queue ++= items
+ // The priority queue keeps the largest elements, so reverse the ordering.
+ val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
+ queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
Iterator.single(queue)
}.reduce { (queue1, queue2) =>
queue1 ++= queue2
queue1
- }.toArray.sorted(ord.reverse)
+ }.toArray.sorted(ord)
}
/**
- * Returns the first K elements from this RDD as defined by
- * the specified implicit Ordering[T] and maintains the
- * ordering.
- * @param num the number of top elements to return
- * @param ord the implicit ordering for T
- * @return an array of top elements
- */
- def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse)
+ * Returns the max of this RDD as defined by the implicit Ordering[T].
+ * @return the maximum element of the RDD
+ */
+ def max()(implicit ord: Ordering[T]): T = this.reduce(ord.max)
+
+ /**
+ * Returns the min of this RDD as defined by the implicit Ordering[T].
+ * @return the minimum element of the RDD
+ */
+ def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
/**
* Save this RDD as a text file, using string representations of elements.
@@ -1027,8 +1122,9 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
- /** Record user function generating this RDD. */
- @transient private[spark] val origin = sc.getCallSite()
+ /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
+ @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo
+ private[spark] def getCreationSite: String = creationSiteInfo.toString
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
@@ -1046,9 +1142,9 @@ abstract class RDD[T: ClassTag](
@transient private var doCheckpointCalled = false
/**
- * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
- * after a job using this RDD has completed (therefore the RDD has been materialized and
- * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
+ * Performs the checkpointing of this RDD by saving it. It is called after a job using this RDD
+ * has completed (therefore the RDD has been materialized and potentially stored in memory).
+ * doCheckpoint() is called recursively on the parent RDDs.
*/
private[spark] def doCheckpoint() {
if (!doCheckpointCalled) {
@@ -1091,13 +1187,9 @@ abstract class RDD[T: ClassTag](
}
override def toString: String = "%s%s[%d] at %s".format(
- Option(name).map(_ + " ").getOrElse(""),
- getClass.getSimpleName,
- id,
- origin)
+ Option(name).map(_ + " ").getOrElse(""), getClass.getSimpleName, id, getCreationSite)
def toJavaRDD() : JavaRDD[T] = {
new JavaRDD(this)(elementClassTag)
}
-
}
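A short sketch exercising the RDD API surface touched above: the reworked top/takeOrdered pair, the new max/min, and toLocalIterator. Assumes a live SparkContext named sc; the data is illustrative:

    val nums = sc.parallelize(Seq(10, 4, 2, 12, 3))

    nums.top(1)         // Array(12): the largest element
    nums.takeOrdered(2) // Array(2, 3): the smallest elements, in ascending order
    nums.max()          // 12
    nums.min()          // 2

    // Streams one partition at a time to the driver instead of collect()-ing
    // everything; peak driver memory is bounded by the largest partition.
    nums.toLocalIterator.foreach(println)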
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index b50307cfa49b7..b097c30f8c231 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -26,14 +26,14 @@ import cern.jet.random.engine.DRand
import org.apache.spark.{Partition, TaskContext}
-@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0")
+@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0")
private[spark]
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
override val index: Int = prev.index
}
-@deprecated("Replaced by PartitionwiseSampledRDD", "1.0")
-class SampledRDD[T: ClassTag](
+@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0")
+private[spark] class SampledRDD[T: ClassTag](
prev: RDD[T],
withReplacement: Boolean,
frac: Double,
diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
index 7df9a2960d8a5..9a1efc83cbe6a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -68,8 +68,8 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
val keyClass = getWritableClass[K]
val valueClass = getWritableClass[V]
- val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass)
- val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass)
+ val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass)
+ val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass)
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," +
valueClass.getSimpleName + ")" )
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 0bbda25a905cd..802b0bdfb2d59 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -20,6 +20,8 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -27,26 +29,28 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
}
/**
+ * :: DeveloperApi ::
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
* @tparam K the key class.
* @tparam V the value class.
*/
+@DeveloperApi
class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
@transient var prev: RDD[P],
part: Partitioner)
extends RDD[P](prev.context, Nil) {
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): ShuffledRDD[K, V, P] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
+ this.serializer = serializer
this
}
override def getDependencies: Seq[Dependency[_]] = {
- List(new ShuffleDependency(prev, part, serializerClass))
+ List(new ShuffleDependency(prev, part, serializer))
}
override val partitioner = Some(part)
@@ -57,8 +61,8 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
- SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf))
+ val ser = Serializer.getSerializer(serializer)
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
}
override def clearDependencies() {
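setSerializer now takes a Serializer instance rather than a class-name String, so callers construct the serializer themselves. A sketch of the new call pattern, with KryoSerializer chosen purely for illustration and sc/conf assumed to be a live SparkContext and its SparkConf:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.ShuffledRDD
    import org.apache.spark.serializer.KryoSerializer

    val pairs = sc.parallelize(Seq(1 -> "a", 2 -> "b"))
    // Pass the serializer object directly; leaving it null resolves to the
    // default serializer via Serializer.getSerializer at fetch time.
    val shuffled = new ShuffledRDD[Int, String, (Int, String)](pairs, new HashPartitioner(4))
      .setSerializer(new KryoSerializer(conf))
    shuffled.count()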
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 5fe9f363db453..9a09c05bbc959 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -30,6 +30,7 @@ import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
+import org.apache.spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -53,10 +54,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
+ this.serializer = serializer
this
}
@@ -67,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd, part, serializerClass)
+ new ShuffleDependency(rdd, part, serializer)
}
}
}
@@ -92,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+ val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -105,14 +106,13 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
}
def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
- }
- case ShuffleCoGroupSplitDep(shuffleId) => {
+
+ case ShuffleCoGroupSplitDep(shuffleId) =>
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context, serializer)
+ context, ser)
iter.foreach(op)
- }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index a447030752096..21c6e07d69f90 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
+import org.apache.spark.annotation.DeveloperApi
private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int)
extends Partition {
@@ -43,6 +44,7 @@ private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitInd
}
}
+@DeveloperApi
class UnionRDD[T: ClassTag](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index b56643444aa40..f3d30f6c9b32f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -41,7 +41,7 @@ private[spark] class ZippedPartitionsPartition(
}
}
-abstract class ZippedPartitionsBaseRDD[V: ClassTag](
+private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
sc: SparkContext,
var rdds: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
@@ -74,7 +74,7 @@ abstract class ZippedPartitionsBaseRDD[V: ClassTag](
}
}
-class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
+private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B]) => Iterator[V],
var rdd1: RDD[A],
@@ -94,7 +94,7 @@ class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
}
}
-class ZippedPartitionsRDD3
+private[spark] class ZippedPartitionsRDD3
[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
@@ -119,7 +119,7 @@ class ZippedPartitionsRDD3
}
}
-class ZippedPartitionsRDD4
+private[spark] class ZippedPartitionsRDD4
[A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
index 2119e76f0e032..b8110ffc42f2d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
@@ -44,7 +44,7 @@ private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
}
}
-class ZippedRDD[T: ClassTag, U: ClassTag](
+private[spark] class ZippedRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1: RDD[T],
var rdd2: RDD[U])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala
new file mode 100644
index 0000000000000..cd5d44ad4a7e6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+/**
+ * A simple listener for application events.
+ *
+ * This listener expects to hear events from a single application only. If events
+ * from multiple applications are seen, the behavior is unspecified.
+ */
+private[spark] class ApplicationEventListener extends SparkListener {
+ var appName = ""
+ var sparkUser = ""
+ var startTime = -1L
+ var endTime = -1L
+ var viewAcls = ""
+ var enableViewAcls = false
+
+ def applicationStarted = startTime != -1
+
+ def applicationCompleted = endTime != -1
+
+ def applicationDuration: Long = {
+ val difference = endTime - startTime
+ if (applicationStarted && applicationCompleted && difference > 0) difference else -1L
+ }
+
+ override def onApplicationStart(applicationStart: SparkListenerApplicationStart) {
+ appName = applicationStart.appName
+ startTime = applicationStart.time
+ sparkUser = applicationStart.sparkUser
+ }
+
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) {
+ endTime = applicationEnd.time
+ }
+
+ override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) {
+ synchronized {
+ val environmentDetails = environmentUpdate.environmentDetails
+ val allProperties = environmentDetails("Spark Properties").toMap
+ viewAcls = allProperties.getOrElse("spark.ui.view.acls", "")
+ enableViewAcls = allProperties.getOrElse("spark.ui.acls.enable", "false").toBoolean
+ }
+ }
+}
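A minimal sketch of the listener's lifecycle, driving it with events directly rather than through a bus. The event case-class fields shown (appName, time, sparkUser) are assumptions based on this era of the listener API, and the code would need to live under the org.apache.spark package since the class is private[spark]:

    val listener = new ApplicationEventListener
    listener.onApplicationStart(
      SparkListenerApplicationStart("demo-app", time = 1000L, sparkUser = "alice"))
    listener.onApplicationEnd(SparkListenerApplicationEnd(time = 5000L))

    assert(listener.applicationStarted && listener.applicationCompleted)
    println(listener.applicationDuration) // 4000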
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dc5b25d845dc2..ff411e24a3d85 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -22,17 +22,23 @@ import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.Await
import scala.concurrent.duration._
+import scala.language.postfixOps
import scala.reflect.ClassTag
import akka.actor._
+import akka.actor.OneForOneStrategy
+import akka.actor.SupervisorStrategy.Stop
+import akka.pattern.ask
+import akka.util.Timeout
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.util.Utils
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -47,24 +53,84 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
* not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
- * THREADING: This class runs all its logic in a single thread executing the run() method, to which
- * events are submitted using a synchronized queue (eventQueue). The public API methods, such as
- * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
- * should be private.
*/
private[spark]
class DAGScheduler(
- taskSched: TaskScheduler,
+ private[scheduler] val sc: SparkContext,
+ private[scheduler] val taskScheduler: TaskScheduler,
+ listenerBus: LiveListenerBus,
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends Logging {
- def this(taskSched: TaskScheduler) {
- this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
- SparkEnv.get.blockManager.master, SparkEnv.get)
+ import DAGScheduler._
+
+ def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+ this(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ sc.env.blockManager.master,
+ sc.env)
+ }
+
+ def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+ private[scheduler] val nextJobId = new AtomicInteger(0)
+ private[scheduler] def numTotalJobs: Int = nextJobId.get()
+ private val nextStageId = new AtomicInteger(0)
+
+ private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
+ private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
+ private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
+ private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
+ private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
+ private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
+ private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]
+
+ // Stages we need to run whose parents aren't done
+ private[scheduler] val waitingStages = new HashSet[Stage]
+
+ // Stages we are running right now
+ private[scheduler] val runningStages = new HashSet[Stage]
+
+ // Stages that must be resubmitted due to fetch failures
+ private[scheduler] val failedStages = new HashSet[Stage]
+
+ // Missing tasks from each stage
+ private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
+
+ private[scheduler] val activeJobs = new HashSet[ActiveJob]
+
+ // Contains the locations that each RDD's partitions are cached on
+ private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
+
+ // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
+ // every task. When we detect a node failing, we note the current epoch number and failed
+ // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results.
+ //
+ // TODO: Garbage collect information about failure epochs when we know there are no more
+ // stray messages to detect.
+ private val failedEpoch = new HashMap[String, Long]
+
+ private val dagSchedulerActorSupervisor =
+ env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
+
+ private[scheduler] var eventProcessActor: ActorRef = _
+
+ private def initializeEventProcessActor() {
+ // Block the thread until the supervisor is started, which ensures that eventProcessActor
+ // is not null before any job is submitted.
+ implicit val timeout = Timeout(30 seconds)
+ val initEventActorReply =
+ dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
+ eventProcessActor = Await.result(initEventActorReply, timeout.duration).
+ asInstanceOf[ActorRef]
}
- taskSched.setDAGScheduler(this)
+
+ initializeEventProcessActor()
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
@@ -72,8 +138,8 @@ class DAGScheduler(
}
// Called to report that a task has completed and results are being fetched remotely.
- def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
- eventProcessActor ! GettingResultEvent(task, taskInfo)
+ def taskGettingResult(taskInfo: TaskInfo) {
+ eventProcessActor ! GettingResultEvent(taskInfo)
}
// Called by TaskScheduler to report task completions or failures.
@@ -93,8 +159,8 @@ class DAGScheduler(
}
// Called by TaskScheduler when a host is added
- def executorGained(execId: String, host: String) {
- eventProcessActor ! ExecutorGained(execId, host)
+ def executorAdded(execId: String, host: String) {
+ eventProcessActor ! ExecutorAdded(execId, host)
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
@@ -103,106 +169,9 @@ class DAGScheduler(
eventProcessActor ! TaskSetFailed(taskSet, reason)
}
- // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
- // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
- // as more failure events come in
- val RESUBMIT_TIMEOUT = 200.milliseconds
-
- // The time, in millis, to wake up between polls of the completion queue in order to potentially
- // resubmit failed stages
- val POLL_TIMEOUT = 10L
-
- // Warns the user if a stage contains a task with size greater than this value (in KB)
- val TASK_SIZE_TO_WARN = 100
-
- private var eventProcessActor: ActorRef = _
-
- private[scheduler] val nextJobId = new AtomicInteger(0)
-
- def numTotalJobs: Int = nextJobId.get()
-
- private val nextStageId = new AtomicInteger(0)
-
- private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]]
-
- private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]
-
- private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage]
-
- private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
-
- private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
-
- // An async scheduler event bus. The bus should be stopped when DAGSCheduler is stopped.
- private[spark] val listenerBus = new SparkListenerBus
-
- // Contains the locations that each RDD's partitions are cached on
- private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
-
- // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
- // every task. When we detect a node failing, we note the current epoch number and failed
- // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results.
- //
- // TODO: Garbage collect information about failure epochs when we know there are no more
- // stray messages to detect.
- val failedEpoch = new HashMap[String, Long]
-
- // stage id to the active job
- val idToActiveJob = new HashMap[Int, ActiveJob]
-
- val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
- val running = new HashSet[Stage] // Stages we are running right now
- val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
- // Missing tasks from each stage
- val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]]
-
- val activeJobs = new HashSet[ActiveJob]
- val resultStageToJob = new HashMap[Stage, ActiveJob]
-
- val metadataCleaner = new MetadataCleaner(
- MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf)
-
- /**
- * Starts the event processing actor. The actor has two responsibilities:
- *
- * 1. Waits for events like job submission, task finished, task failure etc., and calls
- * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them.
- * 2. Schedules a periodical task to resubmit failed stages.
- *
- * NOTE: the actor cannot be started in the constructor, because the periodical task references
- * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus
- * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed.
- */
- def start() {
- eventProcessActor = env.actorSystem.actorOf(Props(new Actor {
- /**
- * The main event loop of the DAG scheduler.
- */
- def receive = {
- case event: DAGSchedulerEvent =>
- logTrace("Got event of type " + event.getClass.getName)
-
- /**
- * All events are forwarded to `processEvent()`, so that the event processing logic can
- * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite`
- * for details.
- */
- if (!processEvent(event)) {
- submitWaitingStages()
- } else {
- context.stop(self)
- }
- }
- }))
- }
-
- def addSparkListener(listener: SparkListener) {
- listenerBus.addListener(listener)
- }
-
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
if (!cacheLocs.contains(rdd.id)) {
- val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
+ val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
cacheLocs(rdd.id) = blockIds.map { id =>
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
@@ -250,7 +219,7 @@ class DAGScheduler(
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
- stageToInfos(stage) = new StageInfo(stage)
+ stageToInfos(stage) = StageInfo.fromStage(stage)
stage
}
@@ -269,7 +238,7 @@ class DAGScheduler(
: Stage =
{
val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
- if (mapOutputTracker.has(shuffleDep.shuffleId)) {
+ if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
for (i <- 0 until locs.size) {
@@ -279,7 +248,7 @@ class DAGScheduler(
} else {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
}
stage
@@ -356,29 +325,31 @@ class DAGScheduler(
}
/**
- * Removes job and any stages that are not needed by any other job. Returns the set of ids for
- * stages that were removed. The associated tasks for those stages need to be cancelled if we
- * got here via job cancellation.
+ * Removes state for the job and any stages that are not needed by any other job. Does not
+ * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
+ *
+ * @param job The job whose state to clean up.
+ * @param resultStage Specifies the result stage for the job; if set to None, this method
+ * searches resultStageToJob to find and clean up the appropriate result stage.
*/
- private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
- val registeredStages = jobIdToStageIds(jobId)
- val independentStages = new HashSet[Int]()
- if (registeredStages.isEmpty) {
- logError("No stages registered for job " + jobId)
+ private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
+ val registeredStages = jobIdToStageIds.get(job.jobId)
+ if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
+ logError("No stages registered for job " + job.jobId)
} else {
- stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
+ stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
case (stageId, jobSet) =>
- if (!jobSet.contains(jobId)) {
+ if (!jobSet.contains(job.jobId)) {
logError(
"Job %d not registered for stage %d even though that stage was registered for the job"
- .format(jobId, stageId))
+ .format(job.jobId, stageId))
} else {
def removeStage(stageId: Int) {
// data structures based on Stage
for (stage <- stageIdToStage.get(stageId)) {
- if (running.contains(stage)) {
+ if (runningStages.contains(stage)) {
logDebug("Removing running stage %d".format(stageId))
- running -= stage
+ runningStages -= stage
}
stageToInfos -= stage
for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
@@ -388,40 +359,48 @@ class DAGScheduler(
logDebug("Removing pending status for stage %d".format(stageId))
}
pendingTasks -= stage
- if (waiting.contains(stage)) {
+ if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
- waiting -= stage
+ waitingStages -= stage
}
- if (failed.contains(stage)) {
+ if (failedStages.contains(stage)) {
logDebug("Removing stage %d from failed set.".format(stageId))
- failed -= stage
+ failedStages -= stage
}
}
// data structures based on StageId
stageIdToStage -= stageId
stageIdToJobIds -= stageId
+ ShuffleMapTask.removeStage(stageId)
+ ResultTask.removeStage(stageId)
+
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
- jobSet -= jobId
+ jobSet -= job.jobId
if (jobSet.isEmpty) { // no other job needs this stage
- independentStages += stageId
removeStage(stageId)
}
}
}
}
- independentStages.toSet
- }
-
- private def jobIdToStageIdsRemove(jobId: Int) {
- if (!jobIdToStageIds.contains(jobId)) {
- logDebug("Trying to remove unregistered job " + jobId)
+ jobIdToStageIds -= job.jobId
+ jobIdToActiveJob -= job.jobId
+ activeJobs -= job
+
+ if (resultStage.isEmpty) {
+ // Clean up result stages.
+ val resultStagesForJob = resultStageToJob.keySet.filter(
+ stage => resultStageToJob(stage).jobId == job.jobId)
+ if (resultStagesForJob.size != 1) {
+ logWarning(
+ s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
+ }
+ resultStageToJob --= resultStagesForJob
} else {
- removeJobAndIndependentStages(jobId)
- jobIdToStageIds -= jobId
+ resultStageToJob -= resultStage.get
}
}
@@ -440,7 +419,7 @@ class DAGScheduler(
{
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
- partitions.find(p => p >= maxPartitions).foreach { p =>
+ partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
throw new IllegalArgumentException(
"Attempting to access a non-existent partition: " + p + ". " +
"Total number of partitions: " + maxPartitions)
@@ -471,7 +450,7 @@ class DAGScheduler(
val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
- case JobFailed(exception: Exception, _) =>
+ case JobFailed(exception: Exception) =>
logInfo("Failed to run " + callSite)
throw exception
}
@@ -515,111 +494,20 @@ class DAGScheduler(
eventProcessActor ! AllJobsCancelled
}
+ private[scheduler] def doCancelAllJobs() {
+ // Cancel all running jobs.
+ runningStages.map(_.jobId).foreach(handleJobCancellation(_,
+ reason = "as part of cancellation of all jobs"))
+ activeJobs.clear() // These should already be empty by this point,
+ jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
+ submitWaitingStages()
+ }
+
/**
- * Process one event retrieved from the event processing actor.
- *
- * @param event The event to be processed.
- * @return `true` if we should stop the event loop.
+ * Cancel all jobs associated with a running or scheduled stage.
*/
- private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
- event match {
- case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- var finalStage: Stage = null
- try {
- // New stage creation may throw an exception if, for example, jobs are run on a HadoopRDD
- // whose underlying HDFS files have been deleted.
- finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
- } catch {
- case e: Exception =>
- logWarning("Creating new stage failed due to exception - job: " + jobId, e)
- listener.jobFailed(e)
- return false
- }
- val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
- clearCacheLocs()
- logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
- " output partitions (allowLocal=" + allowLocal + ")")
- logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
- logInfo("Parents of final stage: " + finalStage.parents)
- logInfo("Missing parents: " + getMissingParentStages(finalStage))
- if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
- // Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(SparkListenerJobStart(job, Array(), properties))
- runLocally(job)
- } else {
- idToActiveJob(jobId) = job
- activeJobs += job
- resultStageToJob(finalStage) = job
- listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties))
- submitStage(finalStage)
- }
-
- case JobCancelled(jobId) =>
- handleJobCancellation(jobId)
-
- case JobGroupCancelled(groupId) =>
- // Cancel all jobs belonging to this job group.
- // First finds all active jobs with this group id, and then kill stages for them.
- val activeInGroup = activeJobs.filter(activeJob =>
- groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
- val jobIds = activeInGroup.map(_.jobId)
- jobIds.foreach { handleJobCancellation }
-
- case AllJobsCancelled =>
- // Cancel all running jobs.
- running.map(_.jobId).foreach { handleJobCancellation }
- activeJobs.clear() // These should already be empty by this point,
- idToActiveJob.clear() // but just in case we lost track of some jobs...
-
- case ExecutorGained(execId, host) =>
- handleExecutorGained(execId, host)
-
- case ExecutorLost(execId) =>
- handleExecutorLost(execId)
-
- case BeginEvent(task, taskInfo) =>
- for (
- job <- idToActiveJob.get(task.stageId);
- stage <- stageIdToStage.get(task.stageId);
- stageInfo <- stageToInfos.get(stage)
- ) {
- if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 &&
- !stageInfo.emittedTaskSizeWarning) {
- stageInfo.emittedTaskSizeWarning = true
- logWarning(("Stage %d (%s) contains a task of very large " +
- "size (%d KB). The maximum recommended task size is %d KB.").format(
- task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN))
- }
- }
- listenerBus.post(SparkListenerTaskStart(task, taskInfo))
-
- case GettingResultEvent(task, taskInfo) =>
- listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
-
- case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
- listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
- handleTaskCompletion(completion)
-
- case TaskSetFailed(taskSet, reason) =>
- stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) }
-
- case ResubmitFailedStages =>
- if (failed.size > 0) {
- // Failed stages may be removed by job cancellation, so failed might be empty even if
- // the ResubmitFailedStages event has been scheduled.
- resubmitFailedStages()
- }
-
- case StopDAGScheduler =>
- // Cancel any active jobs
- for (job <- activeJobs) {
- val error = new SparkException("Job cancelled because SparkContext was shut down")
- job.listener.jobFailed(error)
- listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, None)))
- }
- return true
- }
- false
+ def cancelStage(stageId: Int) {
+ eventProcessActor ! StageCancelled(stageId)
}
/**
@@ -627,29 +515,34 @@ class DAGScheduler(
* the last fetch failure.
*/
private[scheduler] def resubmitFailedStages() {
- logInfo("Resubmitting failed stages")
- clearCacheLocs()
- val failed2 = failed.toArray
- failed.clear()
- for (stage <- failed2.sortBy(_.jobId)) {
- submitStage(stage)
+ if (failedStages.size > 0) {
+      // Failed stages may be removed by job cancellation, so failedStages may be empty even if
+      // the ResubmitFailedStages event has been scheduled.
+ logInfo("Resubmitting failed stages")
+ clearCacheLocs()
+ val failedStagesCopy = failedStages.toArray
+ failedStages.clear()
+ for (stage <- failedStagesCopy.sortBy(_.jobId)) {
+ submitStage(stage)
+ }
}
+ submitWaitingStages()
}
/**
* Check for waiting or failed stages which are now eligible for resubmission.
* Ordinarily run on every iteration of the event loop.
*/
- private[scheduler] def submitWaitingStages() {
+ private def submitWaitingStages() {
// TODO: We might want to run this less often, when we are sure that something has become
// runnable that wasn't before.
logTrace("Checking for newly runnable parent stages")
- logTrace("running: " + running)
- logTrace("waiting: " + waiting)
- logTrace("failed: " + failed)
- val waiting2 = waiting.toArray
- waiting.clear()
- for (stage <- waiting2.sortBy(_.jobId)) {
+ logTrace("running: " + runningStages)
+ logTrace("waiting: " + waitingStages)
+ logTrace("failed: " + failedStages)
+ val waitingStagesCopy = waitingStages.toArray
+ waitingStages.clear()
+ for (stage <- waitingStagesCopy.sortBy(_.jobId)) {
submitStage(stage)
}
}
@@ -685,7 +578,7 @@ class DAGScheduler(
}
} catch {
case e: Exception =>
- jobResult = JobFailed(e, Some(job.finalStage))
+ jobResult = JobFailed(e)
job.listener.jobFailed(e)
} finally {
val s = job.finalStage
@@ -693,7 +586,7 @@ class DAGScheduler(
stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through
stageToInfos -= s // completion events or stage abort
jobIdToStageIds -= job.jobId
- listenerBus.post(SparkListenerJobEnd(job, jobResult))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
}
}
@@ -705,29 +598,125 @@ class DAGScheduler(
private def activeJobForStage(stage: Stage): Option[Int] = {
if (stageIdToJobIds.contains(stage.id)) {
val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
- jobsThatUseStage.find(idToActiveJob.contains(_))
+ jobsThatUseStage.find(jobIdToActiveJob.contains)
} else {
None
}
}
+ private[scheduler] def handleJobGroupCancelled(groupId: String) {
+ // Cancel all jobs belonging to this job group.
+ // First finds all active jobs with this group id, and then kill stages for them.
+ val activeInGroup = activeJobs.filter(activeJob =>
+ groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ val jobIds = activeInGroup.map(_.jobId)
+    jobIds.foreach(handleJobCancellation(_, "as part of cancelled job group %s".format(groupId)))
+ submitWaitingStages()
+ }
+
+ private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
+ for (stage <- stageIdToStage.get(task.stageId); stageInfo <- stageToInfos.get(stage)) {
+ if (taskInfo.serializedSize > DAGScheduler.TASK_SIZE_TO_WARN * 1024 &&
+ !stageInfo.emittedTaskSizeWarning) {
+ stageInfo.emittedTaskSizeWarning = true
+ logWarning(("Stage %d (%s) contains a task of very large " +
+ "size (%d KB). The maximum recommended task size is %d KB.").format(
+ task.stageId, stageInfo.name, taskInfo.serializedSize / 1024,
+ DAGScheduler.TASK_SIZE_TO_WARN))
+ }
+ }
+ listenerBus.post(SparkListenerTaskStart(task.stageId, taskInfo))
+ submitWaitingStages()
+ }
+
+ private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) {
+    stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) }
+ submitWaitingStages()
+ }
+
+ private[scheduler] def cleanUpAfterSchedulerStop() {
+ for (job <- activeJobs) {
+ val error = new SparkException("Job cancelled because SparkContext was shut down")
+ job.listener.jobFailed(error)
+ // Tell the listeners that all of the running stages have ended. Don't bother
+ // cancelling the stages because if the DAG scheduler is stopped, the entire application
+ // is in the process of getting stopped.
+ val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
+ runningStages.foreach { stage =>
+ val info = stageToInfos(stage)
+ info.stageFailed(stageFailedMessage)
+ listenerBus.post(SparkListenerStageCompleted(info))
+ }
+ listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ }
+ }
+
+ private[scheduler] def handleGetTaskResult(taskInfo: TaskInfo) {
+ listenerBus.post(SparkListenerTaskGettingResult(taskInfo))
+ submitWaitingStages()
+ }
+
+ private[scheduler] def handleJobSubmitted(jobId: Int,
+ finalRDD: RDD[_],
+ func: (TaskContext, Iterator[_]) => _,
+ partitions: Array[Int],
+ allowLocal: Boolean,
+ callSite: String,
+ listener: JobListener,
+ properties: Properties = null)
+ {
+ var finalStage: Stage = null
+ try {
+ // New stage creation may throw an exception if, for example, jobs are run on a
+ // HadoopRDD whose underlying HDFS files have been deleted.
+ finalStage = newStage(finalRDD, partitions.size, None, jobId, Some(callSite))
+ } catch {
+ case e: Exception =>
+ logWarning("Creating new stage failed due to exception - job: " + jobId, e)
+ listener.jobFailed(e)
+ return
+ }
+ if (finalStage != null) {
+ val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
+ clearCacheLocs()
+ logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format(
+ job.jobId, callSite, partitions.length, allowLocal))
+ logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties))
+ runLocally(job)
+ } else {
+ jobIdToActiveJob(jobId) = job
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
+ properties))
+ submitStage(finalStage)
+ }
+ }
+ submitWaitingStages()
+ }
+
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
val jobId = activeJobForStage(stage)
if (jobId.isDefined) {
logDebug("submitStage(" + stage + ")")
- if (!waiting(stage) && !running(stage) && !failed(stage)) {
+ if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
if (missing == Nil) {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
submitMissingTasks(stage, jobId.get)
- running += stage
+ runningStages += stage
} else {
for (parent <- missing) {
submitStage(parent)
}
- waiting += stage
+ waitingStages += stage
}
}
} else {
@@ -758,10 +747,10 @@ class DAGScheduler(
}
}
- val properties = if (idToActiveJob.contains(jobId)) {
- idToActiveJob(stage.jobId).properties
+ val properties = if (jobIdToActiveJob.contains(jobId)) {
+      jobIdToActiveJob(jobId).properties
} else {
- //this stage will be assigned to "default" pool
+ // this stage will be assigned to "default" pool
null
}
@@ -779,20 +768,20 @@ class DAGScheduler(
} catch {
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString)
- running -= stage
+ runningStages -= stage
return
}
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
- taskSched.submitTasks(
+ taskScheduler.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
- running -= stage
+ runningStages -= stage
}
}
@@ -800,9 +789,12 @@ class DAGScheduler(
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
- private def handleTaskCompletion(event: CompletionEvent) {
+ private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
-
+ val stageId = task.stageId
+ val taskType = Utils.getFormattedClassName(task)
+ listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
+ event.taskMetrics))
if (!stageIdToStage.contains(task.stageId)) {
// Skip all the actions if the stage has been cancelled.
return
@@ -817,7 +809,7 @@ class DAGScheduler(
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
- running -= stage
+ runningStages -= stage
}
event.reason match {
case Success =>
@@ -826,7 +818,6 @@ class DAGScheduler(
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
}
pendingTasks(stage) -= task
- stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics
task match {
case rt: ResultTask[_, _] =>
resultStageToJob.get(stage) match {
@@ -836,12 +827,9 @@ class DAGScheduler(
job.numFinished += 1
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
- idToActiveJob -= stage.jobId
- activeJobs -= job
- resultStageToJob -= stage
markStageAsFinished(stage)
- jobIdToStageIdsRemove(job.jobId)
- listenerBus.post(SparkListenerJobEnd(job, JobSucceeded))
+ cleanupStateForJobAndIndependentStages(job, Some(stage))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
@@ -858,12 +846,12 @@ class DAGScheduler(
} else {
stage.addOutputLoc(smt.partitionId, status)
}
- if (running.contains(stage) && pendingTasks(stage).isEmpty) {
+ if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
markStageAsFinished(stage)
logInfo("looking for newly runnable stages")
- logInfo("running: " + running)
- logInfo("waiting: " + waiting)
- logInfo("failed: " + failed)
+ logInfo("running: " + runningStages)
+ logInfo("waiting: " + waitingStages)
+ logInfo("failed: " + failedStages)
if (stage.shuffleDep.isDefined) {
// We supply true to increment the epoch number here in case this is a
// recomputation of the map outputs. In that case, some nodes may have cached
@@ -886,14 +874,14 @@ class DAGScheduler(
submitStage(stage)
} else {
val newlyRunnable = new ArrayBuffer[Stage]
- for (stage <- waiting) {
+ for (stage <- waitingStages) {
logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
}
- for (stage <- waiting if getMissingParentStages(stage) == Nil) {
+ for (stage <- waitingStages if getMissingParentStages(stage) == Nil) {
newlyRunnable += stage
}
- waiting --= newlyRunnable
- running ++= newlyRunnable
+ waitingStages --= newlyRunnable
+ runningStages ++= newlyRunnable
for {
stage <- newlyRunnable.sortBy(_.id)
jobId <- activeJobForStage(stage)
@@ -912,7 +900,7 @@ class DAGScheduler(
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
// Mark the stage that the reducer was in as unrunnable
val failedStage = stageIdToStage(task.stageId)
- running -= failedStage
+ runningStages -= failedStage
// TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " (" + failedStage.name +
") for resubmision due to a fetch failure")
@@ -924,7 +912,7 @@ class DAGScheduler(
}
logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name +
"); marking it for resubmission")
- if (failed.isEmpty && eventProcessActor != null) {
+ if (failedStages.isEmpty && eventProcessActor != null) {
          // Don't schedule an event to resubmit failed stages if failedStages isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
@@ -932,8 +920,8 @@ class DAGScheduler(
env.actorSystem.scheduler.scheduleOnce(
RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
}
- failed += failedStage
- failed += mapStage
+ failedStages += failedStage
+ failedStages += mapStage
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, Some(task.epoch))
@@ -949,6 +937,7 @@ class DAGScheduler(
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}
+ submitWaitingStages()
}
/**
@@ -958,7 +947,7 @@ class DAGScheduler(
* Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- private def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
@@ -978,37 +967,45 @@ class DAGScheduler(
logDebug("Additional executor lost message for " + execId +
"(epoch " + currentEpoch + ")")
}
+ submitWaitingStages()
}
- private def handleExecutorGained(execId: String, host: String) {
+ private[scheduler] def handleExecutorAdded(execId: String, host: String) {
// remove from failedEpoch(execId) ?
if (failedEpoch.contains(execId)) {
- logInfo("Host gained which was in lost list earlier: " + host)
+ logInfo("Host added was in lost list earlier: " + host)
failedEpoch -= execId
}
+ submitWaitingStages()
+ }
+
+ private[scheduler] def handleStageCancellation(stageId: Int) {
+ if (stageIdToJobIds.contains(stageId)) {
+ val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray
+ jobsThatUseStage.foreach(jobId => {
+ handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId))
+ })
+ } else {
+ logInfo("No active jobs to kill for Stage " + stageId)
+ }
+ submitWaitingStages()
}
- private def handleJobCancellation(jobId: Int) {
+ private[scheduler] def handleJobCancellation(jobId: Int, reason: String = "") {
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
- val independentStages = removeJobAndIndependentStages(jobId)
- independentStages.foreach { taskSched.cancelTasks }
- val error = new SparkException("Job %d cancelled".format(jobId))
- val job = idToActiveJob(jobId)
- job.listener.jobFailed(error)
- jobIdToStageIds -= jobId
- activeJobs -= job
- idToActiveJob -= jobId
- listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage))))
+ failJobAndIndependentStages(jobIdToActiveJob(jobId),
+ "Job %d cancelled %s".format(jobId, reason), None)
}
+ submitWaitingStages()
}
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- private def abortStage(failedStage: Stage, reason: String) {
+ private[scheduler] def abortStage(failedStage: Stage, reason: String) {
if (!stageIdToStage.contains(failedStage.id)) {
// Skip all the actions if the stage has been removed.
return
@@ -1017,19 +1014,61 @@ class DAGScheduler(
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- val error = new SparkException("Job aborted: " + reason)
- job.listener.jobFailed(error)
- jobIdToStageIdsRemove(job.jobId)
- idToActiveJob -= resultStage.jobId
- activeJobs -= job
- resultStageToJob -= resultStage
- listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
+ failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
+ Some(resultStage))
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
}
}
+ /**
+ * Fails a job and all stages that are only used by that job, and cleans up relevant state.
+ *
+   * @param resultStage The result stage for the job, if known. Used to clean up state for the job
+ * slightly more efficiently than when not specified.
+ */
+ private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
+ resultStage: Option[Stage]) {
+ val error = new SparkException(failureReason)
+ job.listener.jobFailed(error)
+
+ val shouldInterruptThread =
+ if (job.properties == null) false
+ else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean
+
+ // Cancel all independent, running stages.
+ val stages = jobIdToStageIds(job.jobId)
+ if (stages.isEmpty) {
+ logError("No stages registered for job " + job.jobId)
+ }
+ stages.foreach { stageId =>
+ val jobsForStage = stageIdToJobIds.get(stageId)
+ if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
+ logError(
+ "Job %d not registered for stage %d even though that stage was registered for the job"
+ .format(job.jobId, stageId))
+ } else if (jobsForStage.get.size == 1) {
+ if (!stageIdToStage.contains(stageId)) {
+ logError("Missing Stage for stage with id $stageId")
+ } else {
+ // This is the only job that uses this stage, so fail the stage if it is running.
+ val stage = stageIdToStage(stageId)
+ if (runningStages.contains(stage)) {
+ taskScheduler.cancelTasks(stageId, shouldInterruptThread)
+ val stageInfo = stageToInfos(stage)
+ stageInfo.stageFailed(failureReason)
+            listenerBus.post(SparkListenerStageCompleted(stageInfo))
+ }
+ }
+ }
+ }
+
+ cleanupStateForJobAndIndependentStages(job, resultStage)
+
+ listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ }
+
/**
* Return true if one of stage's ancestors is target.
*/
@@ -1094,27 +1133,99 @@ class DAGScheduler(
Nil
}
- private def cleanup(cleanupTime: Long) {
- Map(
- "stageIdToStage" -> stageIdToStage,
- "shuffleToMapStage" -> shuffleToMapStage,
- "pendingTasks" -> pendingTasks,
- "stageToInfos" -> stageToInfos,
- "jobIdToStageIds" -> jobIdToStageIds,
- "stageIdToJobIds" -> stageIdToJobIds).
- foreach { case(s, t) => {
- val sizeBefore = t.size
- t.clearOldValues(cleanupTime)
- logInfo("%s %d --> %d".format(s, sizeBefore, t.size))
- }}
+ def stop() {
+ logInfo("Stopping DAGScheduler")
+ dagSchedulerActorSupervisor ! PoisonPill
+ taskScheduler.stop()
}
+}
- def stop() {
- if (eventProcessActor != null) {
- eventProcessActor ! StopDAGScheduler
+private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
+ extends Actor with Logging {
+
+ override val supervisorStrategy =
+ OneForOneStrategy() {
+ case x: Exception =>
+ logError("eventProcesserActor failed due to the error %s; shutting down SparkContext"
+ .format(x.getMessage))
+ dagScheduler.doCancelAllJobs()
+ dagScheduler.sc.stop()
+ Stop
}
- metadataCleaner.cancel()
- taskSched.stop()
- listenerBus.stop()
+
+ def receive = {
+ case p: Props => sender ! context.actorOf(p)
+ case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
+ }
+}
+
+private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
+ extends Actor with Logging {
+
+ override def preStart() {
+ // set DAGScheduler for taskScheduler to ensure eventProcessActor is always
+ // valid when the messages arrive
+ dagScheduler.taskScheduler.setDAGScheduler(dagScheduler)
}
+
+ /**
+ * The main event loop of the DAG scheduler.
+ */
+ def receive = {
+ case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
+ dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
+ listener, properties)
+
+ case StageCancelled(stageId) =>
+ dagScheduler.handleStageCancellation(stageId)
+
+ case JobCancelled(jobId) =>
+ dagScheduler.handleJobCancellation(jobId)
+
+ case JobGroupCancelled(groupId) =>
+ dagScheduler.handleJobGroupCancelled(groupId)
+
+ case AllJobsCancelled =>
+ dagScheduler.doCancelAllJobs()
+
+ case ExecutorAdded(execId, host) =>
+ dagScheduler.handleExecutorAdded(execId, host)
+
+ case ExecutorLost(execId) =>
+ dagScheduler.handleExecutorLost(execId)
+
+ case BeginEvent(task, taskInfo) =>
+ dagScheduler.handleBeginEvent(task, taskInfo)
+
+ case GettingResultEvent(taskInfo) =>
+ dagScheduler.handleGetTaskResult(taskInfo)
+
+ case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ dagScheduler.handleTaskCompletion(completion)
+
+ case TaskSetFailed(taskSet, reason) =>
+ dagScheduler.handleTaskSetFailed(taskSet, reason)
+
+ case ResubmitFailedStages =>
+ dagScheduler.resubmitFailedStages()
+ }
+
+ override def postStop() {
+ // Cancel any active jobs in postStop hook
+ dagScheduler.cleanUpAfterSchedulerStop()
+ }
+}
+
+private[spark] object DAGScheduler {
+  // The time to wait for fetch failure events to stop coming in after one is detected;
+ // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
+ // as more failure events come in
+ val RESUBMIT_TIMEOUT = 200.milliseconds
+
+ // The time, in millis, to wake up between polls of the completion queue in order to potentially
+ // resubmit failed stages
+ val POLL_TIMEOUT = 10L
+
+ // Warns the user if a stage contains a task with size greater than this value (in KB)
+ val TASK_SIZE_TO_WARN = 100
}
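
A side note on the supervision layout this file now uses: the old `StopDAGScheduler` poison pill is replaced by an Akka supervisor that creates the event-processing actor and stops it (and the SparkContext) when its handler throws. The following is a minimal, self-contained sketch of the same one-for-one pattern with hypothetical actor names, not Spark's actual classes:

```scala
import akka.actor.{Actor, ActorSystem, OneForOneStrategy, Props}
import akka.actor.SupervisorStrategy.Stop

// Hypothetical child actor that fails when it receives "boom".
class Worker extends Actor {
  def receive = { case "boom" => throw new RuntimeException("handler failed") }
}

class Supervisor extends Actor {
  // One-for-one: only the failing child is affected, mirroring how
  // DAGSchedulerActorSupervisor stops the event actor on any Exception.
  override val supervisorStrategy = OneForOneStrategy() {
    case e: Exception =>
      println("child failed: " + e.getMessage)
      Stop
  }
  // Children are created on demand by sending Props, as in the diff above.
  def receive = { case p: Props => sender ! context.actorOf(p) }
}

object SupervisorDemo extends App {
  val system = ActorSystem("demo")
  val supervisor = system.actorOf(Props[Supervisor], "supervisor")
  supervisor ! Props[Worker] // the reply (not shown) is the child's ActorRef
}
```
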
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 39cd98e2d74e4..23f57441b4b11 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler
import java.util.Properties
import scala.collection.mutable.Map
+import scala.language.existentials
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
@@ -44,6 +45,8 @@ private[scheduler] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
+private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent
+
private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
@@ -54,7 +57,7 @@ private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler]
-case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+case class GettingResultEvent(taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler] case class CompletionEvent(
task: Task[_],
@@ -65,7 +68,7 @@ private[scheduler] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
@@ -73,5 +76,3 @@ private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
-
-private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index b52fe2410abde..5878e733908f5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -28,15 +28,15 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
val sourceName = "%s.DAGScheduler".format(sc.appName)
metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] {
- override def getValue: Int = dagScheduler.failed.size
+ override def getValue: Int = dagScheduler.failedStages.size
})
metricRegistry.register(MetricRegistry.name("stage", "runningStages"), new Gauge[Int] {
- override def getValue: Int = dagScheduler.running.size
+ override def getValue: Int = dagScheduler.runningStages.size
})
metricRegistry.register(MetricRegistry.name("stage", "waitingStages"), new Gauge[Int] {
- override def getValue: Int = dagScheduler.waiting.size
+ override def getValue: Int = dagScheduler.waitingStages.size
})
metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
new file mode 100644
index 0000000000000..2fe65cd944b67
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.{FileLogger, JsonProtocol}
+
+/**
+ * A SparkListener that logs events to persistent storage.
+ *
+ * Event logging is specified by the following configurable parameters:
+ * spark.eventLog.enabled - Whether event logging is enabled.
+ *   spark.eventLog.compress - Whether to compress logged events.
+ *   spark.eventLog.overwrite - Whether to overwrite any existing files.
+ *   spark.eventLog.dir - Path to the directory in which events are logged.
+ *   spark.eventLog.buffer.kb - Buffer size to use when writing to output streams.
+ */
+private[spark] class EventLoggingListener(
+ appName: String,
+ conf: SparkConf,
+ hadoopConfiguration: Configuration)
+ extends SparkListener with Logging {
+
+ import EventLoggingListener._
+
+ private val shouldCompress = conf.getBoolean("spark.eventLog.compress", false)
+ private val shouldOverwrite = conf.getBoolean("spark.eventLog.overwrite", false)
+ private val outputBufferSize = conf.getInt("spark.eventLog.buffer.kb", 100) * 1024
+ private val logBaseDir = conf.get("spark.eventLog.dir", "/tmp/spark-events").stripSuffix("/")
+ private val name = appName.replaceAll("[ :/]", "-").toLowerCase + "-" + System.currentTimeMillis
+ val logDir = logBaseDir + "/" + name
+
+ private val logger =
+ new FileLogger(logDir, conf, hadoopConfiguration, outputBufferSize, shouldCompress,
+ shouldOverwrite)
+
+ /**
+ * Begin logging events.
+ * If compression is used, log a file that indicates which compression library is used.
+ */
+ def start() {
+ logInfo("Logging events to %s".format(logDir))
+ if (shouldCompress) {
+ val codec = conf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC)
+ logger.newFile(COMPRESSION_CODEC_PREFIX + codec)
+ }
+ logger.newFile(SPARK_VERSION_PREFIX + SparkContext.SPARK_VERSION)
+ logger.newFile(LOG_PREFIX + logger.fileIndex)
+ }
+
+ /** Log the event as JSON. */
+ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) {
+ val eventJson = compact(render(JsonProtocol.sparkEventToJson(event)))
+ logger.logLine(eventJson)
+ if (flushLogger) {
+ logger.flush()
+ }
+ }
+
+ // Events that do not trigger a flush
+ override def onStageSubmitted(event: SparkListenerStageSubmitted) =
+ logEvent(event)
+ override def onTaskStart(event: SparkListenerTaskStart) =
+ logEvent(event)
+ override def onTaskGettingResult(event: SparkListenerTaskGettingResult) =
+ logEvent(event)
+ override def onTaskEnd(event: SparkListenerTaskEnd) =
+ logEvent(event)
+ override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate) =
+ logEvent(event)
+
+ // Events that trigger a flush
+ override def onStageCompleted(event: SparkListenerStageCompleted) =
+ logEvent(event, flushLogger = true)
+ override def onJobStart(event: SparkListenerJobStart) =
+ logEvent(event, flushLogger = true)
+ override def onJobEnd(event: SparkListenerJobEnd) =
+ logEvent(event, flushLogger = true)
+ override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded) =
+ logEvent(event, flushLogger = true)
+ override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved) =
+ logEvent(event, flushLogger = true)
+ override def onUnpersistRDD(event: SparkListenerUnpersistRDD) =
+ logEvent(event, flushLogger = true)
+ override def onApplicationStart(event: SparkListenerApplicationStart) =
+ logEvent(event, flushLogger = true)
+ override def onApplicationEnd(event: SparkListenerApplicationEnd) =
+ logEvent(event, flushLogger = true)
+
+ /**
+ * Stop logging events.
+ * In addition, create an empty special file to indicate application completion.
+ */
+ def stop() = {
+ logger.newFile(APPLICATION_COMPLETE)
+ logger.stop()
+ }
+}
+
+private[spark] object EventLoggingListener extends Logging {
+ val LOG_PREFIX = "EVENT_LOG_"
+ val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
+ val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
+ val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
+
+ // A cache for compression codecs to avoid creating the same codec many times
+ private val codecMap = new mutable.HashMap[String, CompressionCodec]
+
+ def isEventLogFile(fileName: String): Boolean = {
+ fileName.startsWith(LOG_PREFIX)
+ }
+
+ def isSparkVersionFile(fileName: String): Boolean = {
+ fileName.startsWith(SPARK_VERSION_PREFIX)
+ }
+
+ def isCompressionCodecFile(fileName: String): Boolean = {
+ fileName.startsWith(COMPRESSION_CODEC_PREFIX)
+ }
+
+ def isApplicationCompleteFile(fileName: String): Boolean = {
+ fileName == APPLICATION_COMPLETE
+ }
+
+ def parseSparkVersion(fileName: String): String = {
+ if (isSparkVersionFile(fileName)) {
+ fileName.replaceAll(SPARK_VERSION_PREFIX, "")
+ } else ""
+ }
+
+ def parseCompressionCodec(fileName: String): String = {
+ if (isCompressionCodecFile(fileName)) {
+ fileName.replaceAll(COMPRESSION_CODEC_PREFIX, "")
+ } else ""
+ }
+
+ /**
+ * Parse the event logging information associated with the logs in the given directory.
+ *
+ * Specifically, this looks for event log files, the Spark version file, the compression
+ * codec file (if event logs are compressed), and the application completion file (if the
+ * application has run to completion).
+ */
+ def parseLoggingInfo(logDir: Path, fileSystem: FileSystem): EventLoggingInfo = {
+ try {
+ val fileStatuses = fileSystem.listStatus(logDir)
+ val filePaths =
+ if (fileStatuses != null) {
+ fileStatuses.filter(!_.isDir).map(_.getPath).toSeq
+ } else {
+ Seq[Path]()
+ }
+ if (filePaths.isEmpty) {
+ logWarning("No files found in logging directory %s".format(logDir))
+ }
+ EventLoggingInfo(
+ logPaths = filePaths.filter { path => isEventLogFile(path.getName) },
+ sparkVersion = filePaths
+ .find { path => isSparkVersionFile(path.getName) }
+ .map { path => parseSparkVersion(path.getName) }
+ .getOrElse(""),
+ compressionCodec = filePaths
+ .find { path => isCompressionCodecFile(path.getName) }
+ .map { path =>
+ val codec = EventLoggingListener.parseCompressionCodec(path.getName)
+ val conf = new SparkConf
+ conf.set("spark.io.compression.codec", codec)
+ codecMap.getOrElseUpdate(codec, CompressionCodec.createCodec(conf))
+ },
+ applicationComplete = filePaths.exists { path => isApplicationCompleteFile(path.getName) }
+ )
+ } catch {
+ case t: Throwable =>
+ logError("Exception in parsing logging info from directory %s".format(logDir), t)
+ EventLoggingInfo.empty
+ }
+ }
+
+ /**
+ * Parse the event logging information associated with the logs in the given directory.
+ */
+ def parseLoggingInfo(logDir: String, fileSystem: FileSystem): EventLoggingInfo = {
+ parseLoggingInfo(new Path(logDir), fileSystem)
+ }
+}
+
+
+/**
+ * Information needed to process the event logs associated with an application.
+ */
+private[spark] case class EventLoggingInfo(
+ logPaths: Seq[Path],
+ sparkVersion: String,
+ compressionCodec: Option[CompressionCodec],
+ applicationComplete: Boolean = false)
+
+private[spark] object EventLoggingInfo {
+ def empty = EventLoggingInfo(Seq[Path](), "", None, applicationComplete = false)
+}
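
The listener above is driven entirely by the `spark.eventLog.*` parameters listed in its Scaladoc. As a hypothetical driver-side configuration exercising those keys (values are illustrative only):

```scala
import org.apache.spark.SparkConf

// Assumes the spark.eventLog.* keys documented on EventLoggingListener.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("event-log-demo")
  .set("spark.eventLog.enabled", "true")          // turn event logging on
  .set("spark.eventLog.compress", "true")         // compress the log files
  .set("spark.eventLog.dir", "/tmp/spark-events") // where logs are written
```
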
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 5555585c8b4cd..bac37bfdaa23f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -27,11 +27,14 @@ import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.SparkHadoopUtil
/**
+ * :: DeveloperApi ::
* Parses and holds information about inputFormat (and files) specified as a parameter.
*/
+@DeveloperApi
class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
val path: String) extends Logging {
@@ -164,8 +167,7 @@ object InputFormatInfo {
     PS: I know the wording here is weird, hopefully it makes some sense!
*/
- def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]]
- = {
+ def computePreferredLocations(formats: Seq[InputFormatInfo]): Map[String, Set[SplitInfo]] = {
val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
for (inputSplit <- formats) {
@@ -178,6 +180,6 @@ object InputFormatInfo {
}
}
- nodeToSplit
+ nodeToSplit.mapValues(_.toSet).toMap
}
}
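
The signature change above returns an immutable `Map[String, Set[SplitInfo]]` instead of exposing the internal mutable collections. A small sketch of the conversion idiom, with `Int` standing in for `SplitInfo`:

```scala
import scala.collection.mutable

val nodeToSplit = new mutable.HashMap[String, mutable.HashSet[Int]]
nodeToSplit.getOrElseUpdate("host1", new mutable.HashSet[Int]) += 1
nodeToSplit.getOrElseUpdate("host1", new mutable.HashSet[Int]) += 2

// mapValues produces a lazy view in Scala 2.10; the trailing toSet/toMap
// force a strict immutable snapshot, so callers never observe later mutation.
val snapshot: Map[String, Set[Int]] = nodeToSplit.mapValues(_.toSet).toMap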
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 9d75d7c4ad69a..a1e21cad48b9b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -22,24 +22,27 @@ import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+import scala.collection.mutable.HashMap
import org.apache.spark._
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
/**
+ * :: DeveloperApi ::
* A logger class to record runtime information for jobs in Spark. This class outputs one log file
- * for each Spark job, containing RDD graph, tasks start/stop, shuffle information.
- * JobLogger is a subclass of SparkListener, use addSparkListener to add JobLogger to a SparkContext
- * after the SparkContext is created.
- * Note that each JobLogger only works for one SparkContext
- * @param logDirName The base directory for the log files.
+ * for each Spark job, containing tasks start/stop and shuffle information. JobLogger is a subclass
+ * of SparkListener; use addSparkListener to add JobLogger to a SparkContext after the SparkContext
+ * is created. Note that each JobLogger only works for one SparkContext.
+ *
+ * NOTE: The functionality of this class is heavily stripped down to accommodate a general
+ * refactor of the SparkListener interface. In its place, the EventLoggingListener is introduced
+ * to log application information as SparkListenerEvents. To enable this functionality, set
+ * spark.eventLog.enabled to true.
*/
-
-class JobLogger(val user: String, val logDirName: String)
- extends SparkListener with Logging {
+@DeveloperApi
+@deprecated("Log application information by setting spark.eventLog.enabled.", "1.0.0")
+class JobLogger(val user: String, val logDirName: String) extends SparkListener with Logging {
def this() = this(System.getProperty("user.name", ""),
String.valueOf(System.currentTimeMillis()))
@@ -51,19 +54,21 @@ class JobLogger(val user: String, val logDirName: String)
"/tmp/spark-%s".format(user)
}
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+ private val jobIdToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIdToJobId = new HashMap[Int, Int]
+ private val jobIdToStageIds = new HashMap[Int, Seq[Int]]
+ private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
+ override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ }
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]
createLogDir()
// The following 5 functions are used only in testing.
private[scheduler] def getLogDir = logDir
- private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
- private[scheduler] def getStageIDToJobID = stageIDToJobID
- private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getJobIdToPrintWriter = jobIdToPrintWriter
+ private[scheduler] def getStageIdToJobId = stageIdToJobId
+ private[scheduler] def getJobIdToStageIds = jobIdToStageIds
private[scheduler] def getEventQueue = eventQueue
/** Create a folder for log files, the folder's name is the creation time of jobLogger */
@@ -80,191 +85,78 @@ class JobLogger(val user: String, val logDirName: String)
/**
* Create a log file for one job
- * @param jobID ID of the job
- * @exception FileNotFoundException Fail to create log file
+ * @param jobId ID of the job
+   * @throws FileNotFoundException if the log file cannot be created
*/
- protected def createLogWriter(jobID: Int) {
+ protected def createLogWriter(jobId: Int) {
try {
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobId)
+ jobIdToPrintWriter += (jobId -> fileWriter)
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
}
/**
- * Close log file, and clean the stage relationship in stageIDToJobID
- * @param jobID ID of the job
+ * Close log file, and clean the stage relationship in stageIdToJobId
+ * @param jobId ID of the job
*/
- protected def closeLogWriter(jobID: Int) {
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ protected def closeLogWriter(jobId: Int) {
+ jobIdToPrintWriter.get(jobId).foreach { fileWriter =>
fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
+ jobIdToStageIds.get(jobId).foreach(_.foreach { stageId =>
+ stageIdToJobId -= stageId
})
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
+ jobIdToPrintWriter -= jobId
+ jobIdToStageIds -= jobId
}
}
+ /**
+ * Build up the maps that represent stage-job relationships
+ * @param jobId ID of the job
+ * @param stageIds IDs of the associated stages
+ */
+ protected def buildJobStageDependencies(jobId: Int, stageIds: Seq[Int]) = {
+ jobIdToStageIds(jobId) = stageIds
+ stageIds.foreach { stageId => stageIdToJobId(stageId) = jobId }
+ }
+
/**
* Write info into log file
- * @param jobID ID of the job
+ * @param jobId ID of the job
* @param info Info to be recorded
* @param withTime Controls whether to record time stamp before the info, default is true
*/
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ protected def jobLogInfo(jobId: Int, info: String, withTime: Boolean = true) {
var writeInfo = info
if (withTime) {
val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " + info
+ writeInfo = dateFormat.get.format(date) + ": " + info
}
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo))
}
/**
* Write info into log file
- * @param stageID ID of the stage
+ * @param stageId ID of the stage
* @param info Info to be recorded
* @param withTime Controls whether to record time stamp before the info, default is true
*/
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
- }
-
- /**
- * Build stage dependency for a job
- * @param jobID ID of the job
- * @param stage Root stage of the job
- */
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- /**
- * Record stage dependency and RDD dependency for a stage
- * @param jobID Job ID of the stage
- */
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- /**
- * Generate indents and convert to String
- * @param indent Number of indents
- * @return string of indents
- */
- protected def indentString(indent: Int): String = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- /**
- * Get RDD's name
- * @param rdd Input RDD
- * @return String of RDD's name
- */
- protected def getRddName(rdd: RDD[_]): String = {
- var rddName = rdd.getClass.getSimpleName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- /**
- * Record RDD dependency graph in a stage
- * @param jobID Job ID of the stage
- * @param rdd Root RDD of the stage
- * @param indent Indent number before info
- */
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo =
- if (rdd.getStorageLevel != StorageLevel.NONE) {
- "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
- rdd.origin + " " + rdd.generator
- } else {
- "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
- rdd.origin + " " + rdd.generator
- }
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
-
- /**
- * Record stage dependency graph of a job
- * @param jobID Job ID of the stage
- * @param stage Root stage of the job
- * @param indent Indent number before info, default is 0
- */
- protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0)
- {
- val stageInfo = if (stage.isShuffleMap) {
- "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
- } else {
- "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- if (!idSet.contains(stage.id)) {
- idSet += stage.id
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
- }
- } else {
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
+ protected def stageLogInfo(stageId: Int, info: String, withTime: Boolean = true) {
+ stageIdToJobId.get(stageId).foreach(jobId => jobLogInfo(jobId, info, withTime))
}
/**
* Record task metrics into job log files, including execution info and shuffle metrics
- * @param stageID Stage ID of the task
+ * @param stageId Stage ID of the task
* @param status Status info of the task
* @param taskInfo Task description info
* @param taskMetrics Task running metrics
*/
- protected def recordTaskMetrics(stageID: Int, status: String,
+ protected def recordTaskMetrics(stageId: Int, status: String,
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageId +
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
@@ -275,7 +167,6 @@ class JobLogger(val user: String, val logDirName: String)
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
@@ -283,7 +174,7 @@ class JobLogger(val user: String, val logDirName: String)
case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ stageLogInfo(stageId, status + info + executorRunTime + readMetrics + writeMetrics)
}
/**
@@ -291,8 +182,9 @@ class JobLogger(val user: String, val logDirName: String)
* @param stageSubmitted Stage submitted event
*/
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
+ val stageInfo = stageSubmitted.stageInfo
+ stageLogInfo(stageInfo.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageInfo.stageId, stageInfo.numTasks))
}
/**
@@ -300,36 +192,34 @@ class JobLogger(val user: String, val logDirName: String)
* @param stageCompleted Stage completed event
*/
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
- stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
- stageCompleted.stage.stageId))
+ val stageId = stageCompleted.stageInfo.stageId
+ if (stageCompleted.stageInfo.failureReason.isEmpty) {
+ stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED")
+ } else {
+ stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED")
+ }
}
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
/**
* When task ends, record task completion status and metrics
* @param taskEnd Task end event
*/
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
+ var taskStatus = "TASK_TYPE=%s".format(taskEnd.taskType)
+ val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty
taskEnd.reason match {
case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ recordTaskMetrics(taskEnd.stageId, taskStatus, taskInfo, taskMetrics)
case Resubmitted =>
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
+ " STAGE_ID=" + taskEnd.stageId
+ stageLogInfo(taskEnd.stageId, taskStatus)
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
+ stageLogInfo(taskEnd.stageId, taskStatus)
case _ =>
}
}
@@ -339,28 +229,28 @@ class JobLogger(val user: String, val logDirName: String)
* @param jobEnd Job end event
*/
override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
+ val jobId = jobEnd.jobId
+ var info = "JOB_ID=" + jobId
jobEnd.jobResult match {
case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
+ case JobFailed(exception) =>
info += " STATUS=FAILED REASON="
exception.getMessage.split("\\s+").foreach(info += _ + "_")
case _ =>
}
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
+ jobLogInfo(jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(jobId)
}
/**
* Record job properties into job log file
- * @param jobID ID of the job
+ * @param jobId ID of the job
* @param properties Properties of the job
*/
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
+ protected def recordJobProperties(jobId: Int, properties: Properties) {
+ if (properties != null) {
val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
+ jobLogInfo(jobId, description, false)
}
}
@@ -369,14 +259,11 @@ class JobLogger(val user: String, val logDirName: String)
* @param jobStart Job start event
*/
override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
+ val jobId = jobStart.jobId
val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ createLogWriter(jobId)
+ recordJobProperties(jobId, properties)
+ buildJobStageDependencies(jobId, jobStart.stageIds)
+ jobLogInfo(jobId, "JOB_ID=" + jobId + " STATUS=STARTED")
}
}
-
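
One change in this file worth calling out: the shared `DATE_FORMAT` became a `ThreadLocal`, because `SimpleDateFormat` is not thread-safe. The pattern in isolation:

```scala
import java.text.SimpleDateFormat
import java.util.Date

// Each thread lazily gets its own SimpleDateFormat instance, avoiding the
// data races that a single shared formatter would allow under concurrency.
val dateFormat = new ThreadLocal[SimpleDateFormat]() {
  override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
}

def timestamp(): String = dateFormat.get.format(new Date())
```
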
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
index d94f6ad924260..4cd6cbe189aab 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
@@ -17,11 +17,17 @@
package org.apache.spark.scheduler
+import org.apache.spark.annotation.DeveloperApi
+
/**
+ * :: DeveloperApi ::
* A result of a job in the DAGScheduler.
*/
-private[spark] sealed trait JobResult
+@DeveloperApi
+sealed trait JobResult
+
+@DeveloperApi
+case object JobSucceeded extends JobResult
-private[spark] case object JobSucceeded extends JobResult
-private[spark] case class JobFailed(exception: Exception, failedStage: Option[Stage])
- extends JobResult
+private[spark] case class JobFailed(exception: Exception) extends JobResult
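
With the failed `Stage` dropped from `JobFailed`, consumers only pattern-match on the exception, as `JobLogger.onJobEnd` above now does. A minimal sketch (since `JobFailed` is `private[spark]`, this would live inside the spark package):

```scala
// Mirrors how JobLogger.onJobEnd renders the simplified JobResult ADT.
def describe(result: JobResult): String = result match {
  case JobSucceeded         => "STATUS=SUCCESS"
  case JobFailed(exception) => "STATUS=FAILED REASON=" + exception.getMessage
}
```
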
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index b026f860a8cd8..e9bfee2248e5b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -64,7 +64,7 @@ private[spark] class JobWaiter[T](
override def jobFailed(exception: Exception): Unit = synchronized {
_jobFinished = true
- jobResult = JobFailed(exception, None)
+ jobResult = JobFailed(exception)
this.notifyAll()
}
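
`JobWaiter.jobFailed` relies on the classic synchronized/notifyAll completion latch. A stripped-down sketch of that hand-off, with hypothetical names:

```scala
// A minimal completion latch in the style of JobWaiter.
class Waiter {
  private var finished = false
  private var failure: Option[Exception] = None

  def jobFailed(e: Exception): Unit = synchronized {
    finished = true
    failure = Some(e)
    notifyAll() // wake any thread blocked in awaitResult()
  }

  def awaitResult(): Unit = synchronized {
    while (!finished) wait() // loop guards against spurious wakeups
    failure.foreach(e => throw e)
  }
}
```
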
diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
new file mode 100644
index 0000000000000..dec3316bf7745
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.concurrent.{LinkedBlockingQueue, Semaphore}
+
+import org.apache.spark.Logging
+
+/**
+ * Asynchronously passes SparkListenerEvents to registered SparkListeners.
+ *
+ * Until start() is called, all posted events are only buffered. Only after this listener bus
+ * has started will events be actually propagated to all attached listeners. This listener bus
+ * is stopped when it receives a SparkListenerShutdown event, which is posted using stop().
+ */
+private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
+
+ /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
+ * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
+ private val EVENT_QUEUE_CAPACITY = 10000
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY)
+ private var queueFullErrorMessageLogged = false
+ private var started = false
+
+ // A counter that represents the number of events produced and consumed in the queue
+ private val eventLock = new Semaphore(0)
+
+ private val listenerThread = new Thread("SparkListenerBus") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ eventLock.acquire()
+ // Atomically remove and process this event
+ LiveListenerBus.this.synchronized {
+ val event = eventQueue.poll
+ if (event == SparkListenerShutdown) {
+            // Get out of the while loop and shut down the daemon thread
+ return
+ }
+ Option(event).foreach(postToAll)
+ }
+ }
+ }
+ }
+
+ /**
+ * Start sending events to attached listeners.
+ *
+ * This first sends out all buffered events posted before this listener bus has started, then
+ * listens for any additional events asynchronously while the listener bus is still running.
+ * This should only be called once.
+ */
+ def start() {
+ if (started) {
+ throw new IllegalStateException("Listener bus already started!")
+ }
+ listenerThread.start()
+ started = true
+ }
+
+ def post(event: SparkListenerEvent) {
+ val eventAdded = eventQueue.offer(event)
+ if (eventAdded) {
+ eventLock.release()
+ } else if (!queueFullErrorMessageLogged) {
+ logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
+ "rate at which tasks are being started by the scheduler.")
+ queueFullErrorMessageLogged = true
+ }
+ }
+
+ /**
+ * For testing only. Wait until there are no more events in the queue, or until the specified
+   * time has elapsed. Return true if the queue has emptied, and false if the specified time
+   * elapsed before the queue emptied.
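+   *
+   * Example (test-only sketch): assert(bus.waitUntilEmpty(1000), "listener bus did not drain")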
+ */
+ def waitUntilEmpty(timeoutMillis: Int): Boolean = {
+ val finishTime = System.currentTimeMillis + timeoutMillis
+ while (!queueIsEmpty) {
+ if (System.currentTimeMillis > finishTime) {
+ return false
+ }
+ /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
+ * add overhead in the general case. */
+ Thread.sleep(10)
+ }
+ true
+ }
+
+ /**
+ * Return whether the event queue is empty.
+ *
+   * If this returns true, the use of synchronized here guarantees that all events that once
+   * belonged to this queue have already been processed by all attached listeners.
+ */
+ def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty }
+
+ def stop() {
+ if (!started) {
+ throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
+ }
+ post(SparkListenerShutdown)
+ listenerThread.join()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 4bc13c23d980b..187672c4e19e7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -62,7 +62,7 @@ private[spark] class Pool(
override def addSchedulable(schedulable: Schedulable) {
schedulableQueue += schedulable
schedulableNameToSchedulable(schedulable.name) = schedulable
- schedulable.parent= this
+ schedulable.parent = this
}
override def removeSchedulable(schedulable: Schedulable) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
new file mode 100644
index 0000000000000..f89724d4ea196
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{BufferedInputStream, InputStream}
+
+import scala.io.Source
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.Logging
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.JsonProtocol
+
+/**
+ * A SparkListenerBus that replays logged events from persisted storage.
+ *
+ * This assumes the given paths are valid log files, where each line can be deserialized into
+ * exactly one SparkListenerEvent.
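+ *
+ * A sketch of typical use (illustrative; the log path and FileSystem setup are assumptions):
+ * {{{
+ *   val fs = FileSystem.get(new org.apache.hadoop.conf.Configuration())
+ *   val bus = new ReplayListenerBus(Seq(new Path("/tmp/spark-events/app-0000")), fs, None)
+ *   bus.addListener(new SparkListener {})
+ *   bus.replay()
+ * }}}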
+ */
+private[spark] class ReplayListenerBus(
+ logPaths: Seq[Path],
+ fileSystem: FileSystem,
+ compressionCodec: Option[CompressionCodec])
+ extends SparkListenerBus with Logging {
+
+ private var replayed = false
+
+  if (logPaths.isEmpty) {
+ logWarning("Log path provided contains no log files.")
+ }
+
+ /**
+ * Replay each event in the order maintained in the given logs.
+   * This should be called exactly once.
+ */
+ def replay() {
+ assert(!replayed, "ReplayListenerBus cannot replay events more than once")
+ logPaths.foreach { path =>
+ // Keep track of input streams at all levels to close them later
+ // This is necessary because an exception can occur in between stream initializations
+ var fileStream: Option[InputStream] = None
+ var bufferedStream: Option[InputStream] = None
+ var compressStream: Option[InputStream] = None
+ var currentLine = ""
+ try {
+ fileStream = Some(fileSystem.open(path))
+ bufferedStream = Some(new BufferedInputStream(fileStream.get))
+ compressStream = Some(wrapForCompression(bufferedStream.get))
+
+ // Parse each line as an event and post the event to all attached listeners
+ val lines = Source.fromInputStream(compressStream.get).getLines()
+ lines.foreach { line =>
+ currentLine = line
+ postToAll(JsonProtocol.sparkEventFromJson(parse(line)))
+ }
+ } catch {
+ case e: Exception =>
+ logError("Exception in parsing Spark event log %s".format(path), e)
+ logError("Malformed line: %s\n".format(currentLine))
+ } finally {
+ fileStream.foreach(_.close())
+ bufferedStream.foreach(_.close())
+ compressStream.foreach(_.close())
+ }
+ }
+ replayed = true
+ }
+
+ /** If a compression codec is specified, wrap the given stream in a compression stream. */
+ private def wrapForCompression(stream: InputStream): InputStream = {
+ compressionCodec.map(_.compressedInputStream(stream)).getOrElse(stream)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 3fc6cc9850feb..0e8d551e4b2ab 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -20,21 +20,18 @@ package org.apache.spark.scheduler
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+import scala.collection.mutable.HashMap
+import scala.language.existentials
+
import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.RDDCheckpointData
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
private[spark] object ResultTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
-
- // TODO: This object shouldn't have global variables
- val metadataCleaner = new MetadataCleaner(
- MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf)
+ private val serializedInfoCache = new HashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
{
@@ -58,7 +55,6 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
{
- val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
@@ -67,6 +63,10 @@ private[spark] object ResultTask {
(rdd, func)
}
+ def removeStage(stageId: Int) {
+ serializedInfoCache.remove(stageId)
+ }
+
def clearCache() {
synchronized {
serializedInfoCache.clear()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index e4eced383c3a5..6c5827f75e636 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -23,6 +23,7 @@ import java.util.{NoSuchElementException, Properties}
import scala.xml.XML
import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.util.Utils
/**
* An interface to build Schedulable tree
@@ -72,7 +73,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
schedulerAllocFile.map { f =>
new FileInputStream(f)
}.getOrElse {
- getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
+ Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index eefc8c232b564..6a6d8e609bc39 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
/**
* A backend interface for scheduling systems that allows plugging in different ones under
- * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
+ * TaskSchedulerImpl. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
private[spark] trait SchedulerBackend {
@@ -28,5 +28,6 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
- def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
+ def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
+ throw new UnsupportedOperationException
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
index 3832ee7ff6eef..75186b6ba4a41 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
@@ -25,5 +25,5 @@ package org.apache.spark.scheduler
object SchedulingMode extends Enumeration {
type SchedulingMode = Value
- val FAIR,FIFO,NONE = Value
+ val FAIR, FIFO, NONE = Value
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 77789031f464a..02b62de7e36b6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -21,24 +21,20 @@ import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
+import scala.language.existentials
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.RDDCheckpointData
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
private[spark] object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
-
- // TODO: This object shouldn't have global variables
- val metadataCleaner = new MetadataCleaner(
- MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf)
+ private val serializedInfoCache = new HashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
@@ -79,6 +75,10 @@ private[spark] object ShuffleMapTask {
HashMap(set.toSeq: _*)
}
+ def removeStage(stageId: Int) {
+ serializedInfoCache.remove(stageId)
+ }
+
def clearCache() {
synchronized {
serializedInfoCache.clear()
@@ -153,7 +153,7 @@ private[spark] class ShuffleMapTask(
try {
// Obtain all the block writers for shuffle blocks.
- val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf)
+ val ser = Serializer.getSerializer(dep.serializer)
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 9590c03f10632..378cf1aaebe7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -19,40 +19,80 @@ package org.apache.spark.scheduler
import java.util.Properties
+import scala.collection.Map
+import scala.collection.mutable
+
import org.apache.spark.{Logging, TaskEndReason}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{Distribution, Utils}
-sealed trait SparkListenerEvents
+@DeveloperApi
+sealed trait SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null)
+ extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent
-case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
- extends SparkListenerEvents
+@DeveloperApi
+case class SparkListenerTaskEnd(
+ stageId: Int,
+ taskType: String,
+ reason: TaskEndReason,
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics)
+ extends SparkListenerEvent
-case class SparkListenerStageCompleted(stage: StageInfo) extends SparkListenerEvents
+@DeveloperApi
+case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null)
+ extends SparkListenerEvent
-case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+@DeveloperApi
+case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent
-case class SparkListenerTaskGettingResult(
- task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+@DeveloperApi
+case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]])
+ extends SparkListenerEvent
-case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
- taskMetrics: TaskMetrics) extends SparkListenerEvents
+@DeveloperApi
+case class SparkListenerBlockManagerAdded(blockManagerId: BlockManagerId, maxMem: Long)
+ extends SparkListenerEvent
-case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int],
- properties: Properties = null) extends SparkListenerEvents
+@DeveloperApi
+case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId)
+ extends SparkListenerEvent
-case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
- extends SparkListenerEvents
+@DeveloperApi
+case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent
+
+case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String)
+ extends SparkListenerEvent
+
+case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent
/** An event used in the listener to shutdown the listener daemon thread. */
-private[scheduler] case object SparkListenerShutdown extends SparkListenerEvents
+private[spark] case object SparkListenerShutdown extends SparkListenerEvent
+
/**
- * Interface for listening to events from the Spark scheduler.
+ * :: DeveloperApi ::
+ * Interface for listening to events from the Spark scheduler. Note that this is an internal
+ * interface which might change in different Spark releases.
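+ *
+ * A minimal implementation sketch (illustrative):
+ * {{{
+ *   class MyListener extends SparkListener {
+ *     override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ *       println("task ended in stage " + taskEnd.stageId + ": " + taskEnd.reason)
+ *     }
+ *   }
+ * }}}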
*/
+@DeveloperApi
trait SparkListener {
/**
- * Called when a stage is completed, with information on the completed stage
+ * Called when a stage completes successfully or fails, with information on the completed stage.
*/
def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { }
@@ -87,97 +127,146 @@ trait SparkListener {
*/
def onJobEnd(jobEnd: SparkListenerJobEnd) { }
+ /**
+ * Called when environment properties have been updated
+ */
+ def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { }
+
+ /**
+ * Called when a new block manager has joined
+ */
+ def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { }
+
+ /**
+ * Called when an existing block manager has been removed
+ */
+ def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { }
+
+ /**
+ * Called when an RDD is manually unpersisted by the application
+ */
+ def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) { }
+
+ /**
+ * Called when the application starts
+ */
+ def onApplicationStart(applicationStart: SparkListenerApplicationStart) { }
+
+ /**
+ * Called when the application ends
+ */
+ def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { }
}
/**
+ * :: DeveloperApi ::
* Simple SparkListener that logs a few summary statistics when each stage completes
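+ *
+ * Attach it like any other listener (sketch; assumes an in-scope SparkContext `sc`):
+ * {{{
+ *   sc.addSparkListener(new StatsReportListener)
+ * }}}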
*/
+@DeveloperApi
class StatsReportListener extends SparkListener with Logging {
+
+ import org.apache.spark.scheduler.StatsReportListener._
+
+ private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val info = taskEnd.taskInfo
+ val metrics = taskEnd.taskMetrics
+ if (info != null && metrics != null) {
+ taskInfoMetrics += ((info, metrics))
+ }
+ }
+
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
- import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stage)
- showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
+ this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics)
- //shuffle write
+ // Shuffle write
showBytesDistribution("shuffle bytes written:",
- (_,metric) => metric.shuffleWriteMetrics.map(_.shuffleBytesWritten))
+ (_, metric) => metric.shuffleWriteMetrics.map(_.shuffleBytesWritten), taskInfoMetrics)
- //fetch & io
+ // Fetch & I/O
showMillisDistribution("fetch wait time:",
- (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime))
+ (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics)
showBytesDistribution("remote bytes read:",
- (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead))
- showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize))
+ (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics)
+ showBytesDistribution("task result size:",
+ (_, metric) => Some(metric.resultSize), taskInfoMetrics)
- //runtime breakdown
-
- val runtimePcts = stageCompleted.stage.taskInfos.map{
- case (info, metrics) => RuntimePercentage(info.duration, metrics)
+ // Runtime breakdown
+ val runtimePcts = taskInfoMetrics.map { case (info, metrics) =>
+ RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ",
- Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
+ Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%")
showDistribution("fetch wait time pct: ",
- Distribution(runtimePcts.flatMap{_.fetchPct.map{_ * 100}}), "%2.0f %%")
- showDistribution("other time pct: ", Distribution(runtimePcts.map{_.other * 100}), "%2.0f %%")
+ Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%")
+ showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%")
+ taskInfoMetrics.clear()
}
}
private[spark] object StatsReportListener extends Logging {
- //for profiling, the extremes are more interesting
+ // For profiling, the extremes are more interesting
val percentiles = Array[Int](0,5,10,25,50,75,90,95,100)
- val probabilities = percentiles.map{_ / 100.0}
+ val probabilities = percentiles.map(_ / 100.0)
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
- def extractDoubleDistribution(stage: SparkListenerStageCompleted,
- getMetric: (TaskInfo,TaskMetrics) => Option[Double])
- : Option[Distribution] = {
- Distribution(stage.stage.taskInfos.flatMap {
- case ((info,metric)) => getMetric(info, metric)})
+ def extractDoubleDistribution(
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
+ getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = {
+ Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) })
}
- //is there some way to setup the types that I can get rid of this completely?
- def extractLongDistribution(stage: SparkListenerStageCompleted,
- getMetric: (TaskInfo,TaskMetrics) => Option[Long])
- : Option[Distribution] = {
- extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble})
+  // Is there some way to set up the types so that I can get rid of this completely?
+ def extractLongDistribution(
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
+ getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = {
+ extractDoubleDistribution(
+ taskInfoMetrics,
+ (info, metric) => { getMetric(info, metric).map(_.toDouble) })
}
def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
val stats = d.statCounter
- val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+ val quantiles = d.getQuantiles(probabilities).map(formatNumber)
logInfo(heading + stats)
logInfo(percentilesHeader)
logInfo("\t" + quantiles.mkString("\t"))
}
- def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String)
- {
+ def showDistribution(
+ heading: String,
+ dOpt: Option[Distribution],
+ formatNumber: Double => String) {
dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
}
def showDistribution(heading: String, dOpt: Option[Distribution], format:String) {
- def f(d:Double) = format.format(d)
+ def f(d: Double) = format.format(d)
showDistribution(heading, dOpt, f _)
}
def showDistribution(
heading: String,
format: String,
- getMetric: (TaskInfo, TaskMetrics) => Option[Double])
- (implicit stage: SparkListenerStageCompleted) {
- showDistribution(heading, extractDoubleDistribution(stage, getMetric), format)
+ getMetric: (TaskInfo, TaskMetrics) => Option[Double],
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format)
}
- def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long])
- (implicit stage: SparkListenerStageCompleted) {
- showBytesDistribution(heading, extractLongDistribution(stage, getMetric))
+ def showBytesDistribution(
+      heading: String,
+ getMetric: (TaskInfo, TaskMetrics) => Option[Long],
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
}
def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
- dOpt.foreach{dist => showBytesDistribution(heading, dist)}
+ dOpt.foreach { dist => showBytesDistribution(heading, dist) }
}
def showBytesDistribution(heading: String, dist: Distribution) {
@@ -189,9 +278,11 @@ private[spark] object StatsReportListener extends Logging {
(d => StatsReportListener.millisToString(d.toLong)): Double => String)
}
- def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long])
- (implicit stage: SparkListenerStageCompleted) {
- showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
+ def showMillisDistribution(
+ heading: String,
+ getMetric: (TaskInfo, TaskMetrics) => Option[Long],
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
}
val seconds = 1000L
@@ -199,7 +290,7 @@ private[spark] object StatsReportListener extends Logging {
val hours = minutes * 60
/**
- * reformat a time interval in milliseconds to a prettier format for output
+ * Reformat a time interval in milliseconds to a prettier format for output
*/
def millisToString(ms: Long) = {
val (size, units) =
@@ -221,8 +312,8 @@ private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Doubl
private object RuntimePercentage {
def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
val denom = totalTime.toDouble
- val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime}
- val fetch = fetchTime.map{_ / denom}
+ val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime)
+ val fetch = fetchTime.map(_ / denom)
val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom
val other = 1.0 - (exec + fetch.getOrElse(0d))
RuntimePercentage(exec, fetch, other)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 17b1328b86788..d6df193d9bcf8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -1,100 +1,71 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
-
-import org.apache.spark.Logging
-
-/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */
-private[spark] class SparkListenerBus extends Logging {
- private val sparkListeners = new ArrayBuffer[SparkListener] with SynchronizedBuffer[SparkListener]
-
- /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
- * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
- private val EVENT_QUEUE_CAPACITY = 10000
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY)
- private var queueFullErrorMessageLogged = false
-
- // Create a new daemon thread to listen for events. This thread is stopped when it receives
- // a SparkListenerShutdown event, using the stop method.
- new Thread("SparkListenerBus") {
- setDaemon(true)
- override def run() {
- while (true) {
- val event = eventQueue.take
- event match {
- case stageSubmitted: SparkListenerStageSubmitted =>
- sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
- case stageCompleted: SparkListenerStageCompleted =>
- sparkListeners.foreach(_.onStageCompleted(stageCompleted))
- case jobStart: SparkListenerJobStart =>
- sparkListeners.foreach(_.onJobStart(jobStart))
- case jobEnd: SparkListenerJobEnd =>
- sparkListeners.foreach(_.onJobEnd(jobEnd))
- case taskStart: SparkListenerTaskStart =>
- sparkListeners.foreach(_.onTaskStart(taskStart))
- case taskGettingResult: SparkListenerTaskGettingResult =>
- sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
- case taskEnd: SparkListenerTaskEnd =>
- sparkListeners.foreach(_.onTaskEnd(taskEnd))
- case SparkListenerShutdown =>
- // Get out of the while loop and shutdown the daemon thread
- return
- case _ =>
- }
- }
- }
- }.start()
-
- def addListener(listener: SparkListener) {
- sparkListeners += listener
- }
-
- def post(event: SparkListenerEvents) {
- val eventAdded = eventQueue.offer(event)
- if (!eventAdded && !queueFullErrorMessageLogged) {
- logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
- "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
- "rate at which tasks are being started by the scheduler.")
- queueFullErrorMessageLogged = true
- }
- }
-
- /**
- * Waits until there are no more events in the queue, or until the specified time has elapsed.
- * Used for testing only. Returns true if the queue has emptied and false is the specified time
- * elapsed before the queue emptied.
- */
- def waitUntilEmpty(timeoutMillis: Int): Boolean = {
- val finishTime = System.currentTimeMillis + timeoutMillis
- while (!eventQueue.isEmpty) {
- if (System.currentTimeMillis > finishTime) {
- return false
- }
- /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
- * add overhead in the general case. */
- Thread.sleep(10)
- }
- true
- }
-
- def stop(): Unit = post(SparkListenerShutdown)
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * A SparkListenerEvent bus that relays events to its listeners
+ */
+private[spark] trait SparkListenerBus {
+
+ // SparkListeners attached to this event bus
+ protected val sparkListeners = new ArrayBuffer[SparkListener]
+ with mutable.SynchronizedBuffer[SparkListener]
+
+ def addListener(listener: SparkListener) {
+ sparkListeners += listener
+ }
+
+ /**
+ * Post an event to all attached listeners. This does nothing if the event is
+ * SparkListenerShutdown.
+ */
+ protected def postToAll(event: SparkListenerEvent) {
+ event match {
+ case stageSubmitted: SparkListenerStageSubmitted =>
+ sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
+ case stageCompleted: SparkListenerStageCompleted =>
+ sparkListeners.foreach(_.onStageCompleted(stageCompleted))
+ case jobStart: SparkListenerJobStart =>
+ sparkListeners.foreach(_.onJobStart(jobStart))
+ case jobEnd: SparkListenerJobEnd =>
+ sparkListeners.foreach(_.onJobEnd(jobEnd))
+ case taskStart: SparkListenerTaskStart =>
+ sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
+ case taskEnd: SparkListenerTaskEnd =>
+ sparkListeners.foreach(_.onTaskEnd(taskEnd))
+ case environmentUpdate: SparkListenerEnvironmentUpdate =>
+ sparkListeners.foreach(_.onEnvironmentUpdate(environmentUpdate))
+ case blockManagerAdded: SparkListenerBlockManagerAdded =>
+ sparkListeners.foreach(_.onBlockManagerAdded(blockManagerAdded))
+ case blockManagerRemoved: SparkListenerBlockManagerRemoved =>
+ sparkListeners.foreach(_.onBlockManagerRemoved(blockManagerRemoved))
+ case unpersistRDD: SparkListenerUnpersistRDD =>
+ sparkListeners.foreach(_.onUnpersistRDD(unpersistRDD))
+ case applicationStart: SparkListenerApplicationStart =>
+ sparkListeners.foreach(_.onApplicationStart(applicationStart))
+ case applicationEnd: SparkListenerApplicationEnd =>
+ sparkListeners.foreach(_.onApplicationEnd(applicationEnd))
+ case SparkListenerShutdown =>
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
index 5b40a3eb29b30..1ce83485f024b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
@@ -19,10 +19,17 @@ package org.apache.spark.scheduler
import collection.mutable.ArrayBuffer
+import org.apache.spark.annotation.DeveloperApi
+
// information about a specific split instance : handles both split instances.
// So that we do not need to worry about the differences.
-class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String,
- val length: Long, val underlyingSplit: Any) {
+@DeveloperApi
+class SplitInfo(
+ val inputFormatClazz: Class[_],
+ val hostLocation: String,
+ val path: String,
+ val length: Long,
+ val underlyingSplit: Any) {
override def toString(): String = {
"SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz +
", hostLocation : " + hostLocation + ", path : " + path +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index a78b0186b9eab..5c1fc30e4a557 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -100,7 +100,7 @@ private[spark] class Stage(
id
}
- val name = callSite.getOrElse(rdd.origin)
+ val name = callSite.getOrElse(rdd.getCreationSite)
override def toString = "Stage " + id
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 8f320e5c7a74b..b42e231e11f91 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -17,28 +17,41 @@
package org.apache.spark.scheduler
-import scala.collection._
-
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.storage.RDDInfo
/**
+ * :: DeveloperApi ::
* Stores information about a stage to pass from the scheduler to SparkListeners.
- *
- * taskInfos stores the metrics for all tasks that have completed, including redundant, speculated
- * tasks.
*/
-class StageInfo(
- stage: Stage,
- val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] =
- mutable.Buffer[(TaskInfo, TaskMetrics)]()
-) {
- val stageId = stage.id
+@DeveloperApi
+class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo]) {
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
+ /** Time when all tasks in the stage completed or when the stage was cancelled. */
var completionTime: Option[Long] = None
- val rddName = stage.rdd.name
- val name = stage.name
- val numPartitions = stage.numPartitions
- val numTasks = stage.numTasks
+ /** If the stage failed, the reason why. */
+ var failureReason: Option[String] = None
+
var emittedTaskSizeWarning = false
+
+ def stageFailed(reason: String) {
+ failureReason = Some(reason)
+ completionTime = Some(System.currentTimeMillis)
+ }
+}
+
+private[spark] object StageInfo {
+ /**
+ * Construct a StageInfo from a Stage.
+ *
+ * Each Stage is associated with one or many RDDs, with the boundary of a Stage marked by
+ * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
+ * sequence of narrow dependencies should also be associated with this Stage.
+ */
+ def fromStage(stage: Stage): StageInfo = {
+ val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
+ val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
+ new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index b85b4a50cd93a..2ca3479c80efc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,13 +17,11 @@
package org.apache.spark.scheduler
-import java.io.{DataInputStream, DataOutputStream}
+import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-
import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
@@ -46,8 +44,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ taskThread = Thread.currentThread()
if (_killed) {
- kill()
+ kill(interruptThread = false)
}
runTask(context)
}
@@ -64,6 +63,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
// Task context, to be initialized in run().
@transient protected var context: TaskContext = _
+ // The actual Thread on which the task is running, if any. Initialized in run().
+ @volatile @transient private var taskThread: Thread = _
+
// A flag to indicate whether the task is killed. This is used in case context is not yet
// initialized when kill() is invoked.
@volatile @transient private var _killed = false
@@ -77,12 +79,16 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
* Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
* code and user code to properly handle the flag. This function should be idempotent so it can
* be called multiple times.
+ * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
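+ * One reason to leave interruptThread false is that code the task is running may not handle
+ * Thread.interrupt() safely, so the caller decides whether to request interruption.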
*/
- def kill() {
+ def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
context.interrupted = true
}
+ if (interruptThread && taskThread != null) {
+ taskThread.interrupt()
+ }
}
}
@@ -104,7 +110,7 @@ private[spark] object Task {
serializer: SerializerInstance)
: ByteBuffer = {
- val out = new FastByteArrayOutputStream(4096)
+ val out = new ByteArrayOutputStream(4096)
val dataOut = new DataOutputStream(out)
// Write currentFiles
@@ -125,8 +131,7 @@ private[spark] object Task {
dataOut.flush()
val taskBytes = serializer.serialize(task).array()
out.write(taskBytes)
- out.trim()
- ByteBuffer.wrap(out.array)
+ ByteBuffer.wrap(out.toByteArray)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 6183b125def99..4c62e4dc0bac8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -17,10 +17,13 @@
package org.apache.spark.scheduler
+import org.apache.spark.annotation.DeveloperApi
+
/**
+ * :: DeveloperApi ::
* Information about a running task attempt inside a TaskSet.
*/
-private[spark]
+@DeveloperApi
class TaskInfo(
val taskId: Long,
val index: Int,
@@ -46,15 +49,15 @@ class TaskInfo(
var serializedSize: Int = 0
- def markGettingResult(time: Long = System.currentTimeMillis) {
+ private[spark] def markGettingResult(time: Long = System.currentTimeMillis) {
gettingResultTime = time
}
- def markSuccessful(time: Long = System.currentTimeMillis) {
+ private[spark] def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
- def markFailed(time: Long = System.currentTimeMillis) {
+ private[spark] def markFailed(time: Long = System.currentTimeMillis) {
finishTime = time
failed = true
}
@@ -83,11 +86,11 @@ class TaskInfo(
def duration: Long = {
if (!finished) {
- throw new UnsupportedOperationException("duration() called on unfinished tasks")
+ throw new UnsupportedOperationException("duration() called on unfinished task")
} else {
finishTime - launchTime
}
}
- def timeRunning(currentTime: Long): Long = currentTime - launchTime
+ private[spark] def timeRunning(currentTime: Long): Long = currentTime - launchTime
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
index ea3229b75be36..eb920ab0c0b67 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
@@ -17,8 +17,11 @@
package org.apache.spark.scheduler
-private[spark] object TaskLocality extends Enumeration {
- // process local is expected to be used ONLY within tasksetmanager for now.
+import org.apache.spark.annotation.DeveloperApi
+
+@DeveloperApi
+object TaskLocality extends Enumeration {
+ // Process local is expected to be used ONLY within TaskSetManager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
type TaskLocality = Value
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index cb4ad4ae9350c..c9ad2b151daf0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -85,13 +85,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskEndReason](
- serializedData, getClass.getClassLoader)
+ serializedData, Utils.getSparkClassLoader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic if we can't
// deserialize the reason.
- val loader = Thread.currentThread.getContextClassLoader
+ val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Throwable => {}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 1cdfed1d7005e..819c35257b5a7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -20,7 +20,7 @@ package org.apache.spark.scheduler
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler.
+ * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
 * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
* for a single SparkContext. These schedulers get sets of tasks submitted to them from the
* DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
@@ -47,7 +47,7 @@ private[spark] trait TaskScheduler {
def submitTasks(taskSet: TaskSet): Unit
// Cancel a stage.
- def cancelTasks(stageId: Int)
+ def cancelTasks(stageId: Int, interruptThread: Boolean)
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 8df37c247d0d4..5a68f38bc5844 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -25,6 +25,8 @@ import scala.concurrent.duration._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import scala.language.postfixOps
+import scala.util.Random
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -41,7 +43,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
*
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
- * SchedulerBackends sycnchronize on themselves when they want to send events here, and then
+ * SchedulerBackends synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
@@ -61,6 +63,9 @@ private[spark] class TaskSchedulerImpl(
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000)
+ // CPUs to request per task
+ val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1)
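+  // (Illustrative: with spark.task.cpus=2, an offer of 8 cores runs at most 4 concurrent
+  // tasks, and offers with fewer than 2 free CPUs are skipped when assigning tasks below.)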
+
// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val activeTaskSets = new HashMap[String, TaskSetManager]
@@ -94,8 +99,13 @@ private[spark] class TaskSchedulerImpl(
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
// default scheduler is FIFO
- val schedulingMode: SchedulingMode = SchedulingMode.withName(
- conf.get("spark.scheduler.mode", "FIFO"))
+ private val schedulingModeConf = conf.get("spark.scheduler.mode", "FIFO")
+ val schedulingMode: SchedulingMode = try {
+ SchedulingMode.withName(schedulingModeConf.toUpperCase)
+ } catch {
+ case e: java.util.NoSuchElementException =>
+      throw new SparkException(s"Unrecognized spark.scheduler.mode: $schedulingModeConf")
+ }
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
@@ -160,7 +170,7 @@ private[spark] class TaskSchedulerImpl(
backend.reviveOffers()
}
- override def cancelTasks(stageId: Int): Unit = synchronized {
+ override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
@@ -171,7 +181,7 @@ private[spark] class TaskSchedulerImpl(
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
- backend.killTask(tid, execId)
+ backend.killTask(tid, execId, interruptThread)
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
@@ -203,13 +213,15 @@ private[spark] class TaskSchedulerImpl(
executorIdToHost(o.executorId) = o.host
if (!executorsByHost.contains(o.host)) {
executorsByHost(o.host) = new HashSet[String]()
- executorGained(o.executorId, o.host)
+ executorAdded(o.executorId, o.host)
}
}
- // Build a list of tasks to assign to each worker
- val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
- val availableCpus = offers.map(o => o.cores).toArray
+ // Randomly shuffle offers to avoid always placing tasks on the same set of workers.
+ val shuffledOffers = Random.shuffle(offers)
+ // Build a list of tasks to assign to each worker.
+ val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+ val availableCpus = shuffledOffers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue()
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
@@ -222,18 +234,21 @@ private[spark] class TaskSchedulerImpl(
for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
do {
launchedTask = false
- for (i <- 0 until offers.size) {
- val execId = offers(i).executorId
- val host = offers(i).host
- for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
- tasks(i) += task
- val tid = task.taskId
- taskIdToTaskSetId(tid) = taskSet.taskSet.id
- taskIdToExecutorId(tid) = execId
- activeExecutorIds += execId
- executorsByHost(host) += execId
- availableCpus(i) -= 1
- launchedTask = true
+ for (i <- 0 until shuffledOffers.size) {
+ val execId = shuffledOffers(i).executorId
+ val host = shuffledOffers(i).host
+ if (availableCpus(i) >= CPUS_PER_TASK) {
+ for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
+ tasks(i) += task
+ val tid = task.taskId
+ taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskIdToExecutorId(tid) = execId
+ activeExecutorIds += execId
+ executorsByHost(host) += execId
+ availableCpus(i) -= CPUS_PER_TASK
+            assert(availableCpus(i) >= 0)
+ launchedTask = true
+ }
}
}
} while (launchedTask)
@@ -341,6 +356,7 @@ private[spark] class TaskSchedulerImpl(
if (taskResultGetter != null) {
taskResultGetter.stop()
}
+ starvationTimer.cancel()
// sleeping for an arbitrary 1 second to ensure that messages are sent out.
Thread.sleep(1000L)
@@ -396,8 +412,8 @@ private[spark] class TaskSchedulerImpl(
rootPool.executorLost(executorId, host)
}
- def executorGained(execId: String, host: String) {
- dagScheduler.executorGained(execId, host)
+ def executorAdded(execId: String, host: String) {
+ dagScheduler.executorAdded(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index 03bf76083761f..613fa7850bb25 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,8 +31,8 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt
- def kill() {
- tasks.foreach(_.kill())
+ def kill(interruptThread: Boolean) {
+ tasks.foreach(_.kill(interruptThread))
}
override def toString: String = "TaskSet " + id
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 1a4b7e599c01e..f3bd0797aa035 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -26,13 +26,14 @@ import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
-import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
+import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted,
+ SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{Clock, SystemClock}
/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
+ * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
* each task, retries tasks if they fail (up to a limited number of times), and
* handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
* to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
@@ -41,7 +42,7 @@ import org.apache.spark.util.{Clock, SystemClock}
* THREADING: This class is designed to only be called from code with a lock on the
* TaskScheduler (e.g. its event handlers). It should not be called from other threads.
*
- * @param sched the ClusterScheduler associated with the TaskSetManager
+ * @param sched the TaskSchedulerImpl associated with the TaskSetManager
* @param taskSet the TaskSet to manage scheduling for
* @param maxTaskFailures if any particular task fails more than this number of times, the entire
* task set will be aborted
@@ -55,8 +56,14 @@ private[spark] class TaskSetManager(
{
val conf = sched.sc.conf
- // CPUs to request per task
- val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1)
+ /*
+ * Sometimes if an executor is dead or in an otherwise invalid state, the driver
+   * does not realize it right away, leading to repeated task failures. If enabled,
+ * this temporarily prevents a task from re-launching on an executor where
+ * it just failed.
+ */
+ private val EXECUTOR_TASK_BLACKLIST_TIMEOUT =
+ conf.getLong("spark.scheduler.executorTaskBlacklistTime", 0L)
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
@@ -70,7 +77,9 @@ private[spark] class TaskSetManager(
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
val successful = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
+ private val numFailures = new Array[Int](numTasks)
+ // key is taskId, value is a Map of executor id to when it failed
+ private val failedExecutors = new HashMap[Int, HashMap[String, Long]]()
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksSuccessful = 0
@@ -227,12 +236,18 @@ private[spark] class TaskSetManager(
* This method also cleans up any tasks in the list that have already
* been launched, since we want that to happen lazily.
*/
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
- while (!list.isEmpty) {
- val index = list.last
- list.trimEnd(1)
- if (copiesRunning(index) == 0 && !successful(index)) {
- return Some(index)
+ private def findTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = {
+ var indexOffset = list.size
+
+ while (indexOffset > 0) {
+ indexOffset -= 1
+ val index = list(indexOffset)
+ if (!executorIsBlacklisted(execId, index)) {
+        // In the common case the tail task is not blacklisted, so this is equivalent to
+        // list.trimEnd(1), i.e. an O(1) removal from the end of the list
+ list.remove(indexOffset)
+ if (copiesRunning(index) == 0 && !successful(index)) {
+ return Some(index)
+ }
}
}
None
@@ -243,6 +258,21 @@ private[spark] class TaskSetManager(
taskAttempts(taskIndex).exists(_.host == host)
}
+ /**
+   * Return whether this task attempt has already failed on the given executor within the last
+   * EXECUTOR_TASK_BLACKLIST_TIMEOUT milliseconds, in which case it should not be re-offered there.
+ */
+  private def executorIsBlacklisted(execId: String, taskId: Int): Boolean = {
+    failedExecutors.get(taskId).exists { failed =>
+      failed.get(execId).exists(clock.getTime() - _ < EXECUTOR_TASK_BLACKLIST_TIMEOUT)
+    }
+  }
+
/**
* Return a speculative task for a given executor if any are available. The task should not have
* an attempt running on this host, in case the host is slow. In addition, the task should meet
@@ -253,10 +283,13 @@ private[spark] class TaskSetManager(
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
+ def canRunOnHost(index: Int): Boolean =
+ !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index)
+
if (!speculatableTasks.isEmpty) {
// Check for process-local or preference-less tasks; note that tasks can be process-local
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ for (index <- speculatableTasks if canRunOnHost(index)) {
val prefs = tasks(index).preferredLocations
val executors = prefs.flatMap(_.executorId)
if (prefs.size == 0 || executors.contains(execId)) {
@@ -267,7 +300,7 @@ private[spark] class TaskSetManager(
// Check for node-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ for (index <- speculatableTasks if canRunOnHost(index)) {
val locations = tasks(index).preferredLocations.map(_.host)
if (locations.contains(host)) {
speculatableTasks -= index
@@ -279,7 +312,7 @@ private[spark] class TaskSetManager(
// Check for rack-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
for (rack <- sched.getRackForHost(host)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ for (index <- speculatableTasks if canRunOnHost(index)) {
val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
if (racks.contains(rack)) {
speculatableTasks -= index
@@ -291,7 +324,7 @@ private[spark] class TaskSetManager(
// Check for non-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ for (index <- speculatableTasks if canRunOnHost(index)) {
speculatableTasks -= index
return Some((index, TaskLocality.ANY))
}
@@ -308,12 +341,12 @@ private[spark] class TaskSetManager(
private def findTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
- for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+ for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) {
return Some((index, TaskLocality.PROCESS_LOCAL))
}
if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+ for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) {
return Some((index, TaskLocality.NODE_LOCAL))
}
}
@@ -321,19 +354,19 @@ private[spark] class TaskSetManager(
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
for {
rack <- sched.getRackForHost(host)
- index <- findTaskFromList(getPendingTasksForRack(rack))
+ index <- findTaskFromList(execId, getPendingTasksForRack(rack))
} {
return Some((index, TaskLocality.RACK_LOCAL))
}
}
// Look for no-pref tasks after rack-local tasks since they can run anywhere.
- for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+ for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) {
return Some((index, TaskLocality.PROCESS_LOCAL))
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- findTaskFromList(allPendingTasks)) {
+ for (index <- findTaskFromList(execId, allPendingTasks)) {
return Some((index, TaskLocality.ANY))
}
}
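
findTask above consults the pending lists strictly in order of decreasing locality and returns at the first hit. The fallback chain in isolation (locality names mirror TaskLocality; the rest is a hypothetical sketch):

    object LocalityFallback {
      sealed trait Locality
      case object ProcessLocal extends Locality
      case object NodeLocal extends Locality
      case object RackLocal extends Locality
      case object AnyLocality extends Locality

      // Try each (locality, pending list) pair in order; stop at the first allowed,
      // non-empty list, exactly like the chain of early returns above.
      def findTask(lists: Seq[(Locality, Seq[Int])],
                   allowed: Locality => Boolean): Option[(Int, Locality)] =
        lists.collectFirst {
          case (loc, pending) if allowed(loc) && pending.nonEmpty => (pending.head, loc)
        }

      def main(args: Array[String]) {
        val pending = Seq(ProcessLocal -> Seq.empty[Int], NodeLocal -> Seq(7), AnyLocality -> Seq(3))
        println(findTask(pending, _ => true)) // Some((7,NodeLocal))
      }
    }
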
@@ -348,11 +381,10 @@ private[spark] class TaskSetManager(
def resourceOffer(
execId: String,
host: String,
- availableCpus: Int,
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (!isZombie && availableCpus >= CPUS_PER_TASK) {
+ if (!isZombie) {
val curTime = clock.getTime()
var allowedLocality = getAllowedLocalityLevel(curTime)
@@ -433,7 +465,7 @@ private[spark] class TaskSetManager(
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
- sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+ sched.dagScheduler.taskGettingResult(info)
}
/**
@@ -459,6 +491,7 @@ private[spark] class TaskSetManager(
logInfo("Ignorning task-finished event for TID " + tid + " because task " +
index + " has already completed successfully")
}
+ failedExecutors.remove(index)
maybeFinishTaskSet()
}
@@ -479,7 +512,7 @@ private[spark] class TaskSetManager(
logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
}
var taskMetrics : TaskMetrics = null
- var failureReason = "unknown"
+ var failureReason: String = null
reason match {
case fetchFailed: FetchFailed =>
logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
@@ -487,9 +520,11 @@ private[spark] class TaskSetManager(
successful(index) = true
tasksSuccessful += 1
}
+ // Not adding to failed executors for FetchFailed.
isZombie = true
case TaskKilled =>
+ // Not adding to failed executors for TaskKilled.
logWarning("Task %d was killed.".format(tid))
case ef: ExceptionFailure =>
@@ -503,7 +538,8 @@ private[spark] class TaskSetManager(
return
}
val key = ef.description
- failureReason = "Exception failure: %s".format(ef.description)
+ failureReason = "Exception failure in TID %s on host %s: %s\n%s".format(
+ tid, info.host, ef.description, ef.stackTrace.map(" " + _).mkString("\n"))
val now = clock.getTime()
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
@@ -532,16 +568,21 @@ private[spark] class TaskSetManager(
failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
logWarning(failureReason)
- case _ => {}
+ case _ =>
+ failureReason = "TID %s on host %s failed for unknown reason".format(tid, info.host)
}
+ // Always record the failed executor for this task so executorIsBlacklisted sees it
+ failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
+ put(info.executorId, clock.getTime())
sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
addPendingTask(index)
if (!isZombie && state != TaskState.KILLED) {
+ assert(failureReason != null)
numFailures(index) += 1
if (numFailures(index) >= maxTaskFailures) {
logError("Task %s:%d failed %d times; aborting job".format(
taskSet.id, index, maxTaskFailures))
- abort("Task %s:%d failed %d times (most recent failure: %s)".format(
+ abort("Task %s:%d failed %d times, most recent failure: %s\nDriver stacktrace:".format(
taskSet.id, index, maxTaskFailures, failureReason))
return
}
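
The failure path above now always fills in failureReason, bumps the per-index failure count, and aborts the whole task set once any single index reaches maxTaskFailures. The counting logic reduced to its core (values hypothetical; Spark reads spark.task.maxFailures):

    object FailureCounting {
      val maxTaskFailures = 4 // hypothetical
      val numFailures = new Array[Int](8)

      // Returns true when the task set should be aborted.
      def recordFailure(index: Int): Boolean = {
        numFailures(index) += 1
        numFailures(index) >= maxTaskFailures
      }

      def main(args: Array[String]) {
        println((1 to 4).map(_ => recordFailure(0))) // Vector(false, false, false, true)
      }
    }
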
diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
index ba6bab3f91a65..810b36cddf835 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -21,4 +21,4 @@ package org.apache.spark.scheduler
* Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val executorId: String, val host: String, val cores: Int)
+case class WorkerOffer(executorId: String, host: String, cores: Int)
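
Promoting WorkerOffer to a case class gives it structural equality, a useful toString, and an extractor for free; callers can now destructure offers directly. A small usage sketch (the surrounding code is hypothetical):

    case class WorkerOffer(executorId: String, host: String, cores: Int)

    object WorkerOfferDemo {
      def main(args: Array[String]) {
        val offers = Seq(WorkerOffer("exec-1", "host-a", 4), WorkerOffer("exec-2", "host-b", 0))
        // Pattern matching replaces chains of accessor calls on a plain class.
        val usable = offers.collect { case WorkerOffer(id, host, cores) if cores > 0 => (id, host) }
        println(usable) // List((exec-1,host-a))
      }
    }
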
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 4a9a1659d8254..ddbc74e82ac49 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -30,7 +30,8 @@ private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage
- case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
+ case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
+ extends CoarseGrainedClusterMessage
case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 379e02eb9a437..a6d6b3d26a3c6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -54,6 +54,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
private val executorAddress = new HashMap[String, Address]
private val executorHost = new HashMap[String, String]
private val freeCores = new HashMap[String, Int]
+ private val totalCores = new HashMap[String, Int]
private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
@@ -76,6 +77,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
sender ! RegisteredExecutor(sparkProperties)
executorActor(executorId) = sender
executorHost(executorId) = Utils.parseHostPort(hostPort)._1
+ totalCores(executorId) = cores
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
addressToExecutorId(sender.path.address) = executorId
@@ -87,7 +89,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
if (executorActor.contains(executorId)) {
- freeCores(executorId) += 1
+ freeCores(executorId) += scheduler.CPUS_PER_TASK
makeOffers(executorId)
} else {
// Ignoring the update since we don't know about the executor.
@@ -99,8 +101,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
case ReviveOffers =>
makeOffers()
- case KillTask(taskId, executorId) =>
- executorActor(executorId) ! KillTask(taskId, executorId)
+ case KillTask(taskId, executorId, interruptThread) =>
+ executorActor(executorId) ! KillTask(taskId, executorId, interruptThread)
case StopDriver =>
sender ! true
@@ -138,7 +140,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
// Launch tasks returned by a set of resource offers
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
- freeCores(task.executorId) -= 1
+ freeCores(task.executorId) -= scheduler.CPUS_PER_TASK
executorActor(task.executorId) ! LaunchTask(task)
}
}
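
With the change above, the backend reserves scheduler.CPUS_PER_TASK cores when a task launches and returns the same amount when it finishes, instead of hard-coding 1. The accounting on its own (a sketch; the value 2 is hypothetical, Spark reads spark.task.cpus):

    import scala.collection.mutable.HashMap

    object CoreAccounting {
      val CPUS_PER_TASK = 2 // hypothetical
      val freeCores = HashMap("exec-1" -> 4)

      def launch(execId: String) { freeCores(execId) -= CPUS_PER_TASK }
      def finish(execId: String) { freeCores(execId) += CPUS_PER_TASK }

      def main(args: Array[String]) {
        launch("exec-1"); launch("exec-1")
        println(freeCores("exec-1")) // 0: both running tasks hold 2 cores each
        finish("exec-1")
        println(freeCores("exec-1")) // 2: one task's cores returned
      }
    }
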
@@ -147,10 +149,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
def removeExecutor(executorId: String, reason: String) {
if (executorActor.contains(executorId)) {
logInfo("Executor " + executorId + " disconnected, so removing it")
- val numCores = freeCores(executorId)
- addressToExecutorId -= executorAddress(executorId)
+ val numCores = totalCores(executorId)
executorActor -= executorId
executorHost -= executorId
+ addressToExecutorId -= executorAddress(executorId)
+ executorAddress -= executorId
+ totalCores -= executorId
freeCores -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.executorLost(executorId, SlaveLost(reason))
@@ -168,7 +172,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
properties += ((key, value))
}
}
- //TODO (prashant) send conf instead of properties
+ // TODO (prashant) send conf instead of properties
driverActor = actorSystem.actorOf(
Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
@@ -203,8 +207,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
driverActor ! ReviveOffers
}
- override def killTask(taskId: Long, executorId: String) {
- driverActor ! KillTask(taskId, executorId)
+ override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
+ driverActor ! KillTask(taskId, executorId, interruptThread)
}
override def defaultParallelism(): Int = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ee4b65e312abc..9544ca05dca70 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -18,7 +18,7 @@
package org.apache.spark.scheduler.cluster
import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.deploy.{Command, ApplicationDescription}
+import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
import org.apache.spark.util.Utils
@@ -26,8 +26,7 @@ import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
- masters: Array[String],
- appName: String)
+ masters: Array[String])
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with AppClientListener
with Logging {
@@ -43,14 +42,23 @@ private[spark] class SparkDeploySchedulerBackend(
// The endpoint for executors to talk to us
val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
- conf.get("spark.driver.host"), conf.get("spark.driver.port"),
+ conf.get("spark.driver.host"), conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}")
+ val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
+ val classPathEntries = sys.props.get("spark.executor.extraClassPath").toSeq.flatMap { cp =>
+ cp.split(java.io.File.pathSeparator)
+ }
+ val libraryPathEntries = sys.props.get("spark.executor.extraLibraryPath").toSeq.flatMap { cp =>
+ cp.split(java.io.File.pathSeparator)
+ }
+
val command = Command(
- "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs)
+ "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.testExecutorEnvs,
+ classPathEntries, libraryPathEntries, extraJavaOpts)
val sparkHome = sc.getSparkHome()
- val appDesc = new ApplicationDescription(appName, maxCores, sc.executorMemory, command,
- sparkHome, "http://" + sc.ui.appUIAddress)
+ val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
+ sparkHome, sc.ui.appUIAddress, sc.eventLogger.map(_.logDir))
client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
client.start()
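
classPathEntries and libraryPathEntries above come from splitting an optional property on the platform path separator, yielding an empty Seq when the property is unset. The same transformation in isolation:

    object PathEntries {
      def entries(value: Option[String]): Seq[String] =
        value.toSeq.flatMap(_.split(java.io.File.pathSeparator))

      def main(args: Array[String]) {
        // On Linux/macOS the separator is ':', so this yields two entries.
        println(entries(Some("/opt/libs/a.jar:/opt/libs/b.jar")))
        println(entries(None)) // List(): an absent property contributes nothing
      }
    }
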
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 28b019d9fd495..2cd9d6c12eaf7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -45,8 +45,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
private[spark] class CoarseMesosSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
- master: String,
- appName: String)
+ master: String)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with MScheduler
with Logging {
@@ -94,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -112,7 +111,18 @@ private[spark] class CoarseMesosSchedulerBackend(
def createCommand(offer: Offer, numCores: Int): CommandInfo = {
val environment = Environment.newBuilder()
- sc.executorEnvs.foreach { case (key, value) =>
+ val extraClassPath = conf.getOption("spark.executor.extraClassPath")
+ extraClassPath.foreach { cp =>
+ environment.addVariables(
+ Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
+ }
+ val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions")
+
+ val libraryPathOption = "spark.executor.extraLibraryPath"
+ val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p")
+ val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+
+ sc.testExecutorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
.setName(key)
.setValue(value)
@@ -124,20 +134,22 @@ private[spark] class CoarseMesosSchedulerBackend(
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
val uri = conf.get("spark.executor.uri", null)
if (uri == null) {
val runScript = new File(sparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
- runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %s %d".format(
+ runScript, extraOpts, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
("cd %s*; " +
- "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d")
- .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %s %d")
+ .format(basename, extraOpts, driverUrl, offer.getSlaveId.getValue,
+ offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
command.build()
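
extraOpts above is the space-joined concatenation of whichever of the two optional settings are present; missing options simply drop out. Reproduced standalone:

    object ExtraOpts {
      def extraOpts(javaOpts: Option[String], libraryPath: Option[String]): String =
        Seq(javaOpts, libraryPath.map(p => s"-Djava.library.path=$p")).flatten.mkString(" ")

      def main(args: Array[String]) {
        println(extraOpts(Some("-Xmx2g"), Some("/opt/native")))
        // -Xmx2g -Djava.library.path=/opt/native
        println("'" + extraOpts(None, None) + "'") // '': nothing to add
      }
    }
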
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index c576beb0c0d38..c975f312324ed 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -41,8 +41,7 @@ import org.apache.spark.util.Utils
private[spark] class MesosSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
- master: String,
- appName: String)
+ master: String)
extends SchedulerBackend
with MScheduler
with Logging {
@@ -71,7 +70,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -91,7 +90,7 @@ private[spark] class MesosSchedulerBackend(
"Spark home is not set; set it through the spark.home system " +
"property, the SPARK_HOME environment variable or the SparkContext constructor"))
val environment = Environment.newBuilder()
- sc.executorEnvs.foreach { case (key, value) =>
+ sc.testExecutorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
.setName(key)
.setValue(value)
@@ -203,7 +202,7 @@ private[spark] class MesosSchedulerBackend(
getResource(offer.getResourcesList, "cpus").toInt)
}
- // Call into the ClusterScheduler
+ // Call into the TaskSchedulerImpl
val taskLists = scheduler.resourceOffers(offerableWorkers)
// Build a list of Mesos tasks for each slave
@@ -247,7 +246,7 @@ private[spark] class MesosSchedulerBackend(
val cpuResource = Resource.newBuilder()
.setName("cpus")
.setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(1).build())
+ .setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build())
.build()
MesosTaskInfo.newBuilder()
.setTaskId(taskId)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 50f7e79e97dd8..43f0e18a0cbe0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -30,12 +30,12 @@ private case class ReviveOffers()
private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
-private case class KillTask(taskId: Long)
+private case class KillTask(taskId: Long, interruptThread: Boolean)
/**
* Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
* LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
- * and the ClusterScheduler.
+ * and the TaskSchedulerImpl.
*/
private[spark] class LocalActor(
scheduler: TaskSchedulerImpl,
@@ -61,8 +61,8 @@ private[spark] class LocalActor(
reviveOffers()
}
- case KillTask(taskId) =>
- executor.killTask(taskId)
+ case KillTask(taskId, interruptThread) =>
+ executor.killTask(taskId, interruptThread)
}
def reviveOffers() {
@@ -76,7 +76,7 @@ private[spark] class LocalActor(
/**
* LocalBackend is used when running a local version of Spark where the executor, backend, and
- * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks
+ * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks
* on a single Executor (created by the LocalBackend) running locally.
*/
private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
@@ -99,8 +99,8 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
override def defaultParallelism() = totalCores
- override def killTask(taskId: Long, executorId: String) {
- localActor ! KillTask(taskId)
+ override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
+ localActor ! KillTask(taskId, interruptThread)
}
override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 33c1705ad7c58..e9163deaf2036 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -21,18 +21,38 @@ import java.io._
import java.nio.ByteBuffer
import org.apache.spark.SparkConf
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.Utils
-private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
- val objOut = new ObjectOutputStream(out)
- def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
+private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
+ extends SerializationStream {
+ private val objOut = new ObjectOutputStream(out)
+ private var counter = 0
+
+ /**
+ * Calling reset() periodically avoids a memory leak:
+ * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
+ * But only call it every counterReset-th write (10,000 by default) to avoid bloating the
+ * stream: on each reset, object class descriptions have to be re-written.
+ */
+ def writeObject[T](t: T): SerializationStream = {
+ objOut.writeObject(t)
+ if (counterReset > 0 && counter >= counterReset) {
+ objOut.reset()
+ counter = 0
+ } else {
+ counter += 1
+ }
+ this
+ }
def flush() { objOut.flush() }
def close() { objOut.close() }
}
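
The counter exists because ObjectOutputStream keeps a handle table referencing every object it has written; without a periodic reset() those references are never released. A small demonstration of the pattern outside Spark (the threshold of 100 is hypothetical; the stream above defaults to 10,000):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    object ResetDemo {
      def main(args: Array[String]) {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        var counter = 0
        val counterReset = 100 // hypothetical threshold
        for (i <- 1 to 1000) {
          out.writeObject(java.lang.Integer.valueOf(i))
          if (counter >= counterReset) {
            out.reset() // drop the handle table so written objects can be GC'd
            counter = 0
          } else {
            counter += 1
          }
        }
        out.close()
        println("wrote " + bytes.size + " bytes")
      }
    }
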
private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
- val objIn = new ObjectInputStream(in) {
+ private val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
@@ -41,7 +61,7 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-private[spark] class JavaSerializerInstance extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -63,11 +83,11 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
}
def serializeStream(s: OutputStream): SerializationStream = {
- new JavaSerializationStream(s)
+ new JavaSerializationStream(s, counterReset)
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader)
+ new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
@@ -76,8 +96,24 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
}
/**
+ * :: DeveloperApi ::
* A Spark serializer that uses Java's built-in serialization.
+ *
+ * Note that this serializer is not guaranteed to be wire-compatible across different versions of
+ * Spark. It is intended to be used to serialize/de-serialize data within a single
+ * Spark application.
*/
-class JavaSerializer(conf: SparkConf) extends Serializer {
- def newInstance(): SerializerInstance = new JavaSerializerInstance
+@DeveloperApi
+class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
+ private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
+
+ def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeInt(counterReset)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ counterReset = in.readInt()
+ }
}
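
Implementing Externalizable means only counterReset travels when a JavaSerializer instance is itself serialized. A minimal Externalizable carrying one Int shows the mechanics (the class name is hypothetical):

    import java.io._

    class CounterConfig(var counterReset: Int) extends Externalizable {
      def this() = this(0) // Externalizable requires a public no-arg constructor
      override def writeExternal(out: ObjectOutput) { out.writeInt(counterReset) }
      override def readExternal(in: ObjectInput) { counterReset = in.readInt() }
    }

    object RoundTrip {
      def main(args: Array[String]) {
        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        oos.writeObject(new CounterConfig(10000))
        oos.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
        println(in.readObject().asInstanceOf[CounterConfig].counterReset) // 10000
      }
    }
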
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 920490f9d0d61..c4daec7875d26 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -33,11 +33,19 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
+ *
+ * Note that this serializer is not guaranteed to be wire-compatible across different versions of
+ * Spark. It is intended to be used to serialize/de-serialize data within a single
+ * Spark application.
*/
-class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging {
- private val bufferSize = {
- conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
- }
+class KryoSerializer(conf: SparkConf)
+ extends org.apache.spark.serializer.Serializer
+ with Logging
+ with Serializable {
+
+ private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
+ private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
+ private val registrator = conf.getOption("spark.kryo.registrator")
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -48,9 +56,11 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
- kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true))
+ kryo.setReferences(referenceTracking)
- for (cls <- KryoSerializer.toRegister) kryo.register(cls)
+ for (cls <- KryoSerializer.toRegister) {
+ kryo.register(cls)
+ }
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
@@ -58,7 +68,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
// Allow the user to register their own classes by setting spark.kryo.registrator
try {
- for (regCls <- conf.getOption("spark.kryo.registrator")) {
+ for (regCls <- registrator) {
logDebug("Running user registrator: " + regCls)
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
@@ -69,7 +79,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
}
// Register Chill's classes; we do this after our ranges and the user's own classes to let
- // our code override the generic serialziers in Chill for things like Seq
+ // our code override the generic serializers in Chill for things like Seq
new AllScalaRegistrar().apply(kryo)
kryo.setClassLoader(classLoader)
@@ -103,7 +113,8 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
// DeserializationStream uses the EOF exception to indicate stopping condition.
- case _: KryoException => throw new EOFException
+ case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") =>
+ throw new EOFException
}
}
@@ -167,10 +178,6 @@ private[serializer] object KryoSerializer {
classOf[GetBlock],
classOf[MapStatus],
classOf[BlockManagerId],
- classOf[Array[Byte]],
- (1 to 10).getClass,
- (1 until 10).getClass,
- (1L to 10L).getClass,
- (1L until 10L).getClass
+ classOf[Array[Byte]]
)
}
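
The registrator loop above reflectively instantiates whatever class spark.kryo.registrator names and invokes it after Spark's own registrations, so user choices win. A typical user registrator might look like this (Point and MyRegistrator are hypothetical):

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.serializer.KryoRegistrator

    case class Point(x: Double, y: Double) // hypothetical user class

    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo) {
        kryo.register(classOf[Point])
      }
    }

    // Enabled with: conf.set("spark.kryo.registrator", "MyRegistrator")
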
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index 16677ab54be04..f2c8f9b6218d6 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -17,30 +17,47 @@
package org.apache.spark.serializer
-import java.io.{EOFException, InputStream, OutputStream}
+import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-
+import org.apache.spark.SparkEnv
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
/**
+ * :: DeveloperApi ::
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual
* serialization and are guaranteed to only be called from one thread at a time.
*
- * Implementations of this trait should have a zero-arg constructor or a constructor that accepts a
- * [[org.apache.spark.SparkConf]] as parameter. If both constructors are defined, the latter takes
- * precedence.
+ * Implementations of this trait should implement:
+ *
+ * 1. a zero-arg constructor or a constructor that accepts a [[org.apache.spark.SparkConf]]
+ * as parameter. If both constructors are defined, the latter takes precedence.
+ *
+ * 2. Java serialization interface.
+ *
+ * Note that serializers are not required to be wire-compatible across different versions of Spark.
+ * They are intended to be used to serialize/de-serialize data within a single Spark application.
*/
+@DeveloperApi
trait Serializer {
def newInstance(): SerializerInstance
}
+object Serializer {
+ def getSerializer(serializer: Serializer): Serializer = {
+ if (serializer == null) SparkEnv.get.serializer else serializer
+ }
+}
+
+
/**
+ * :: DeveloperApi ::
* An instance of a serializer, for use by one thread at a time.
*/
+@DeveloperApi
trait SerializerInstance {
def serialize[T](t: T): ByteBuffer
@@ -54,10 +71,9 @@ trait SerializerInstance {
def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
// Default implementation uses serializeStream
- val stream = new FastByteArrayOutputStream()
+ val stream = new ByteArrayOutputStream()
serializeStream(stream).writeAll(iterator)
- val buffer = ByteBuffer.allocate(stream.position.toInt)
- buffer.put(stream.array, 0, stream.position.toInt)
+ // A wrapped array is already positioned for reading, so flip() is no longer needed:
+ // flipping a freshly wrapped buffer would set its limit to 0 and make it look empty.
+ val buffer = ByteBuffer.wrap(stream.toByteArray)
- buffer.flip()
buffer
}
@@ -71,8 +87,10 @@ trait SerializerInstance {
/**
+ * :: DeveloperApi ::
* A stream for writing serialized objects.
*/
+@DeveloperApi
trait SerializationStream {
def writeObject[T](t: T): SerializationStream
def flush(): Unit
@@ -88,8 +106,10 @@ trait SerializationStream {
/**
+ * :: DeveloperApi ::
* A stream for reading serialized objects.
*/
+@DeveloperApi
trait DeserializationStream {
def readObject[T](): T
def close(): Unit
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
deleted file mode 100644
index 65ac0155f45e7..0000000000000
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.serializer
-
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.spark.SparkConf
-
-/**
- * A service that returns a serializer object given the serializer's class name. If a previous
- * instance of the serializer object has been created, the get method returns that instead of
- * creating a new one.
- */
-private[spark] class SerializerManager {
- // TODO: Consider moving this into SparkConf itself to remove the global singleton.
-
- private val serializers = new ConcurrentHashMap[String, Serializer]
- private var _default: Serializer = _
-
- def default = _default
-
- def setDefault(clsName: String, conf: SparkConf): Serializer = {
- _default = get(clsName, conf)
- _default
- }
-
- def get(clsName: String, conf: SparkConf): Serializer = {
- if (clsName == null) {
- default
- } else {
- var serializer = serializers.get(clsName)
- if (serializer != null) {
- // If the serializer has been created previously, reuse that.
- serializer
- } else this.synchronized {
- // Otherwise, create a new one. But make sure no other thread has attempted
- // to create another new one at the same time.
- serializer = serializers.get(clsName)
- if (serializer == null) {
- val clsLoader = Thread.currentThread.getContextClassLoader
- val cls = Class.forName(clsName, true, clsLoader)
-
- // First try with the constructor that takes SparkConf. If we can't find one,
- // use a no-arg constructor instead.
- try {
- val constructor = cls.getConstructor(classOf[SparkConf])
- serializer = constructor.newInstance(conf).asInstanceOf[Serializer]
- } catch {
- case _: NoSuchMethodException =>
- val constructor = cls.getConstructor()
- serializer = constructor.newInstance().asInstanceOf[Serializer]
- }
-
- serializers.put(clsName, serializer)
- }
- serializer
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala
deleted file mode 100644
index 2e0b0e6eda765..0000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala
+++ /dev/null
@@ -1,27 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-private[spark] trait BlockFetchTracker {
- def totalBlocks : Int
- def numLocalBlocks: Int
- def numRemoteBlocks: Int
- def remoteFetchTime : Long
- def fetchWaitTime: Long
- def remoteBytesRead : Long
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 09736dfadac54..a02dd9441d679 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -44,9 +44,13 @@ import org.apache.spark.util.Utils
*/
private[storage]
-trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])]
- with Logging with BlockFetchTracker {
+trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
def initialize()
+ def totalBlocks: Int
+ def numLocalBlocks: Int
+ def numRemoteBlocks: Int
+ def fetchWaitTime: Long
+ def remoteBytesRead: Long
}
@@ -74,7 +78,6 @@ object BlockFetcherIterator {
import blockManager._
private var _remoteBytesRead = 0L
- private var _remoteFetchTime = 0L
private var _fetchWaitTime = 0L
if (blocksByAddress == null) {
@@ -120,7 +123,6 @@ object BlockFetcherIterator {
future.onSuccess {
case Some(message) => {
val fetchDone = System.currentTimeMillis()
- _remoteFetchTime += fetchDone - fetchStart
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
@@ -235,7 +237,15 @@ object BlockFetcherIterator {
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}
- //an iterator that will read fetched blocks off the queue as they arrive.
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
+ override def fetchWaitTime: Long = _fetchWaitTime
+ override def remoteBytesRead: Long = _remoteBytesRead
+
+
+ // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
+ // as they arrive.
@volatile protected var resultsGotten = 0
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
@@ -253,14 +263,6 @@ object BlockFetcherIterator {
}
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
-
- // Implementing BlockFetchTracker trait.
- override def totalBlocks: Int = numLocal + numRemote
- override def numLocalBlocks: Int = numLocal
- override def numRemoteBlocks: Int = numRemote
- override def remoteFetchTime: Long = _remoteFetchTime
- override def fetchWaitTime: Long = _fetchWaitTime
- override def remoteBytesRead: Long = _remoteBytesRead
}
// End of BasicBlockFetcherIterator
@@ -284,7 +286,7 @@ object BlockFetcherIterator {
}
} catch {
case x: InterruptedException => logInfo("Copier Interrupted")
- //case _ => throw new SparkException("Exception Throw in Shuffle Copier")
+ // case _ => throw new SparkException("Exception thrown in Shuffle Copier")
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 301d784b350a3..cffea28fbf794 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
- def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId]
override def toString = name
override def hashCode = name.hashCode
@@ -48,18 +48,13 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI
def name = "rdd_" + rddId + "_" + splitIndex
}
-private[spark]
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+ extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
-private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
- def name = "broadcast_" + broadcastId
-}
-
-private[spark]
-case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
- def name = broadcastId.name + "_" + hType
+private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
+ def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
@@ -83,8 +78,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
- val BROADCAST = "broadcast_([0-9]+)".r
- val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+ val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
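
Collapsing BROADCAST and BROADCAST_HELPER into one regex means the optional field suffix is captured as a second group and stripped of its leading underscore when present. Exercising the pattern directly:

    object BroadcastIdDemo {
      val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r

      def main(args: Array[String]) {
        for (s <- Seq("broadcast_42", "broadcast_42_piece0")) {
          s match {
            case BROADCAST(id, field) => println((id, field.stripPrefix("_")))
          }
        }
        // prints (42,) then (42,piece0)
      }
    }
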
@@ -95,10 +89,8 @@ private[spark] object BlockId {
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
- case BROADCAST(broadcastId) =>
- BroadcastBlockId(broadcastId.toLong)
- case BROADCAST_HELPER(broadcastId, hType) =>
- BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+ case BROADCAST(broadcastId, field) =>
+ BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index d49819125fb12..6d7d4f922e1fa 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -26,22 +26,29 @@ import scala.concurrent.duration._
import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
-import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
-import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.util._
+private[spark] sealed trait Values
+
+private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends Values
+private[spark] case class IteratorValues(iterator: Iterator[Any]) extends Values
+private[spark] case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values
+
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
val defaultSerializer: Serializer,
maxMemory: Long,
- _conf: SparkConf)
+ val _conf: SparkConf,
+ securityManager: SecurityManager,
+ mapOutputTracker: MapOutputTracker)
extends Logging {
def conf = _conf
@@ -51,8 +58,19 @@ private[spark] class BlockManager(
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
- private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+ private[storage] val memoryStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore = new DiskStore(this, diskBlockManager)
+ var tachyonInitialized = false
+ private[storage] lazy val tachyonStore: TachyonStore = {
+ val storeDir = conf.get("spark.tachyonStore.baseDir", "/tmp_spark_tachyon")
+ val appFolderName = conf.get("spark.tachyonStore.folderName")
+ val tachyonStorePath = s"${storeDir}/${appFolderName}/${this.executorId}"
+ val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998")
+ val tachyonBlockManager = new TachyonBlockManager(
+ shuffleBlockManager, tachyonStorePath, tachyonMaster)
+ tachyonInitialized = true
+ new TachyonStore(this, tachyonBlockManager)
+ }
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
@@ -61,7 +79,7 @@ private[spark] class BlockManager(
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
- val connectionManager = new ConnectionManager(0, conf)
+ val connectionManager = new ConnectionManager(0, conf, securityManager)
implicit val futureExecContext = connectionManager.futureExecContext
val blockManagerId = BlockManagerId(
@@ -83,10 +101,10 @@ private[spark] class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
- val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+ val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
- // Pending reregistration action being executed asynchronously or null if none
+ // Pending re-registration action being executed asynchronously or null if none
// is pending. Accesses should synchronize on asyncReregisterLock.
var asyncReregisterTask: Future[Unit] = null
val asyncReregisterLock = new Object
@@ -116,9 +134,16 @@ private[spark] class BlockManager(
/**
* Construct a BlockManager with a memory limit set based on system properties.
*/
- def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer, conf: SparkConf) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf)
+ def this(
+ execId: String,
+ actorSystem: ActorSystem,
+ master: BlockManagerMaster,
+ serializer: Serializer,
+ conf: SparkConf,
+ securityManager: SecurityManager,
+ mapOutputTracker: MapOutputTracker) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
+ conf, securityManager, mapOutputTracker)
}
/**
@@ -141,14 +166,15 @@ private[spark] class BlockManager(
* an executor crash.
*
* This function deliberately fails silently if the master returns false (indicating that
- * the slave needs to reregister). The error condition will be detected again by the next
- * heart beat attempt or new block registration and another try to reregister all blocks
+ * the slave needs to re-register). The error condition will be detected again by the next
+ * heart beat attempt or new block registration and another try to re-register all blocks
* will be made then.
*/
private def reportAllBlocks() {
logInfo("Reporting " + blockInfo.size + " blocks to the master.")
for ((blockId, info) <- blockInfo) {
- if (!tryToReportBlockStatus(blockId, info)) {
+ val status = getCurrentBlockStatus(blockId, info)
+ if (!tryToReportBlockStatus(blockId, info, status)) {
logError("Failed to report " + blockId + " to master; giving up.")
return
}
@@ -156,20 +182,20 @@ private[spark] class BlockManager(
}
/**
- * Reregister with the master and report all blocks to it. This will be called by the heart beat
+ * Re-register with the master and report all blocks to it. This will be called by the heart beat
* thread if our heartbeat to the block manager indicates that we were not registered.
*
* Note that this method must be called without any BlockInfo locks held.
*/
def reregister() {
- // TODO: We might need to rate limit reregistering.
- logInfo("BlockManager reregistering with master")
+ // TODO: We might need to rate limit re-registering.
+ logInfo("BlockManager re-registering with master")
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
reportAllBlocks()
}
/**
- * Reregister with the master sometime soon.
+ * Re-register with the master sometime soon.
*/
def asyncReregister() {
asyncReregisterLock.synchronized {
@@ -185,7 +211,7 @@ private[spark] class BlockManager(
}
/**
- * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing.
+ * For testing. Wait for any pending asynchronous re-registration; otherwise, do nothing.
*/
def waitForAsyncReregister() {
val task = asyncReregisterTask
@@ -195,24 +221,45 @@ private[spark] class BlockManager(
}
/**
- * Get storage level of local block. If no info exists for the block, then returns null.
+ * Get the BlockStatus for the block identified by the given ID, if it exists.
+ * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
+ */
+ def getStatus(blockId: BlockId): Option[BlockStatus] = {
+ blockInfo.get(blockId).map { info =>
+ val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
+ val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L
+ // Assume that block is not in Tachyon
+ BlockStatus(info.level, memSize, diskSize, 0L)
+ }
+ }
+
+ /**
+ * Get the ids of existing blocks that match the given filter. Note that this will
+ * query the blocks stored in the disk block manager (that the block manager
+ * may not know of).
*/
- def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = {
+ (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq
+ }
/**
* Tell the master about the current storage status of a block. This will send a block update
* message reflecting the current status, *not* the desired storage level in its block info.
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
*
- * droppedMemorySize exists to account for when block is dropped from memory to disk (so it
- * is still valid). This ensures that update in master will compensate for the increase in
+ * droppedMemorySize exists to account for when the block is dropped from memory to disk (so
+ * it is still valid). This ensures that update in master will compensate for the increase in
* memory on slave.
*/
- def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) {
- val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
+ def reportBlockStatus(
+ blockId: BlockId,
+ info: BlockInfo,
+ status: BlockStatus,
+ droppedMemorySize: Long = 0L) {
+ val needReregister = !tryToReportBlockStatus(blockId, info, status, droppedMemorySize)
if (needReregister) {
- logInfo("Got told to reregister updating block " + blockId)
- // Reregistering will report our new block for free.
+ logInfo("Got told to re-register updating block " + blockId)
+ // Re-registering will report our new block for free.
asyncReregister()
}
logDebug("Told master about block " + blockId)
@@ -223,27 +270,47 @@ private[spark] class BlockManager(
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo,
+ private def tryToReportBlockStatus(
+ blockId: BlockId,
+ info: BlockInfo,
+ status: BlockStatus,
droppedMemorySize: Long = 0L): Boolean = {
- val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
+ if (info.tellMaster) {
+ val storageLevel = status.storageLevel
+ val inMemSize = Math.max(status.memSize, droppedMemorySize)
+ val inTachyonSize = status.tachyonSize
+ val onDiskSize = status.diskSize
+ master.updateBlockInfo(
+ blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inTachyonSize)
+ } else {
+ true
+ }
+ }
+
+ /**
+ * Return the updated storage status of the block with the given ID. More specifically, if
+ * the block is dropped from memory and possibly added to disk, return the new storage level
+ * and the updated in-memory and on-disk sizes.
+ */
+ private def getCurrentBlockStatus(blockId: BlockId, info: BlockInfo): BlockStatus = {
+ val (newLevel, inMemSize, onDiskSize, inTachyonSize) = info.synchronized {
info.level match {
case null =>
- (StorageLevel.NONE, 0L, 0L, false)
+ (StorageLevel.NONE, 0L, 0L, 0L)
case level =>
val inMem = level.useMemory && memoryStore.contains(blockId)
+ val inTachyon = level.useOffHeap && tachyonStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
- val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
- val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize
+ val deserialized = if (inMem) level.deserialized else false
+ val replication = if (inMem || inTachyon || onDisk) level.replication else 1
+ val storageLevel = StorageLevel(onDisk, inMem, inTachyon, deserialized, replication)
+ val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
+ val tachyonSize = if (inTachyon) tachyonStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
- (storageLevel, memSize, diskSize, info.tellMaster)
+ (storageLevel, memSize, diskSize, tachyonSize)
}
}
-
- if (tellMaster) {
- master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)
- } else {
- true
- }
+ BlockStatus(newLevel, inMemSize, onDiskSize, inTachyonSize)
}
/**
@@ -324,6 +391,24 @@ private[spark] class BlockManager(
}
}
+ // Look for the block in Tachyon
+ if (level.useOffHeap) {
+ logDebug("Getting block " + blockId + " from tachyon")
+ if (tachyonStore.contains(blockId)) {
+ tachyonStore.getBytes(blockId) match {
+ case Some(bytes) => {
+ if (!asValues) {
+ return Some(bytes)
+ } else {
+ return Some(dataDeserialize(blockId, bytes))
+ }
+ }
+ case None =>
+ logDebug("Block " + blockId + " not found in tachyon")
+ }
+ }
+ }
+
// Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
logDebug("Getting block " + blockId + " from disk")
@@ -391,10 +476,10 @@ private[spark] class BlockManager(
/**
* Get block from remote block managers as serialized bytes.
*/
- def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
logDebug("Getting remote block " + blockId + " as bytes")
doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
- }
+ }
private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = {
require(blockId != null, "BlockId is null")
@@ -440,9 +525,8 @@ private[spark] class BlockManager(
* so that we can control the maxMegabytesInFlight for the fetch.
*/
def getMultiple(
- blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer)
- : BlockFetcherIterator = {
-
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+ serializer: Serializer): BlockFetcherIterator = {
val iter =
if (conf.getBoolean("spark.shuffle.use.netty", false)) {
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
@@ -454,53 +538,71 @@ private[spark] class BlockManager(
iter
}
- def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
- : Long = {
- val elements = new ArrayBuffer[Any]
- elements ++= values
- put(blockId, elements, level, tellMaster)
+ def put(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ tellMaster: Boolean): Seq[(BlockId, BlockStatus)] = {
+ doPut(blockId, IteratorValues(values), level, tellMaster)
}
/**
* A short circuited method to get a block writer that can write data directly to disk.
- * The Block will be appended to the File specified by filename.
- * This is currently used for writing shuffle files out. Callers should handle error
- * cases.
+ * The Block will be appended to the File specified by filename. This is currently used for
+ * writing shuffle files out. Callers should handle error cases.
*/
- def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
+ def getDiskWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int): BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites)
}
/**
- * Put a new block of values to the block manager. Returns its (estimated) size in bytes.
+ * Put a new block of values to the block manager. Return a list of blocks updated as a
+ * result of this put.
*/
- def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
- tellMaster: Boolean = true) : Long = {
+ def put(
+ blockId: BlockId,
+ values: ArrayBuffer[Any],
+ level: StorageLevel,
+ tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = {
require(values != null, "Values is null")
- doPut(blockId, Left(values), level, tellMaster)
+ doPut(blockId, ArrayBufferValues(values), level, tellMaster)
}
/**
- * Put a new block of serialized bytes to the block manager.
+ * Put a new block of serialized bytes to the block manager. Return a list of blocks updated
+ * as a result of this put.
*/
- def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
- tellMaster: Boolean = true) {
+ def putBytes(
+ blockId: BlockId,
+ bytes: ByteBuffer,
+ level: StorageLevel,
+ tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = {
require(bytes != null, "Bytes is null")
- doPut(blockId, Right(bytes), level, tellMaster)
+ doPut(blockId, ByteBufferValues(bytes), level, tellMaster)
}
- private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
- level: StorageLevel, tellMaster: Boolean = true): Long = {
+ private def doPut(
+ blockId: BlockId,
+ data: Values,
+ level: StorageLevel,
+ tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = {
+
require(blockId != null, "BlockId is null")
require(level != null && level.isValid, "StorageLevel is null or invalid")
+ // Return value
+ val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = {
+ val putBlockInfo = {
val tinfo = new BlockInfo(level, tellMaster)
// Do atomically !
val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
@@ -508,7 +610,7 @@ private[spark] class BlockManager(
if (oldBlockOpt.isDefined) {
if (oldBlockOpt.get.waitForReady()) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return oldBlockOpt.get.size
+ return updatedBlocks
}
// TODO: So the block info exists - but previous attempt to load it (?) failed.
@@ -530,13 +632,14 @@ private[spark] class BlockManager(
// Ditto for the bytes after the put
var bytesAfterPut: ByteBuffer = null
- // Size of the block in bytes (to return to caller)
+ // Size of the block in bytes
var size = 0L
// If we're storing bytes, then initiate the replication before storing them locally.
// This is faster as data is already serialized and ready to send.
- val replicationFuture = if (data.isRight && level.replication > 1) {
- val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ val replicationFuture = if (data.isInstanceOf[ByteBufferValues] && level.replication > 1) {
+ // Duplicate doesn't copy the bytes, just creates a wrapper
+ val bufferView = data.asInstanceOf[ByteBufferValues].buffer.duplicate()
Future {
replicate(blockId, bufferView, level)
}
@@ -544,58 +647,87 @@ private[spark] class BlockManager(
null
}
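
The duplicate comment above is the key to why this is safe: `duplicate()` creates a view with its own position and limit over the same underlying bytes, so the replication future can consume `bufferView` while the local put rewinds the original. A standalone illustration of that `java.nio` behavior:

```scala
import java.nio.ByteBuffer

val original = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
val view = original.duplicate()   // shares content, not position/limit
original.position(4)              // exhaust the original buffer
assert(view.remaining == 4)       // the view still sees all four bytes
assert(view.get(0) == 1)          // ...because the underlying bytes are shared
```
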
- myInfo.synchronized {
+ putBlockInfo.synchronized {
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
- data match {
- case Left(values) => {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will
- // drop it to disk later if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
- }
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
- }
- }
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = data match {
+ case IteratorValues(iterator) =>
+ memoryStore.putValues(blockId, iterator, level, true)
+ case ArrayBufferValues(array) =>
+ memoryStore.putValues(blockId, array, level, true)
+ case ByteBufferValues(bytes) =>
+ bytes.rewind()
+ memoryStore.putBytes(blockId, bytes, level)
+ }
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
}
- case Right(bytes) => {
- bytes.rewind()
- // Store it only in memory at first, even if useDisk is also set to true
- (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
- size = bytes.limit
+ // Keep track of which blocks are dropped from memory
+ res.droppedBlocks.foreach { block => updatedBlocks += block }
+ } else if (level.useOffHeap) {
+ // Save to Tachyon.
+ val res = data match {
+ case IteratorValues(iterator) =>
+ tachyonStore.putValues(blockId, iterator, level, false)
+ case ArrayBufferValues(array) =>
+ tachyonStore.putValues(blockId, array, level, false)
+ case ByteBufferValues(bytes) =>
+ bytes.rewind()
+ tachyonStore.putBytes(blockId, bytes, level)
+ }
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+
+ val res = data match {
+ case IteratorValues(iterator) =>
+ diskStore.putValues(blockId, iterator, level, askForBytes)
+ case ArrayBufferValues(array) =>
+ diskStore.putValues(blockId, array, level, askForBytes)
+ case ByteBufferValues(bytes) =>
+ bytes.rewind()
+ diskStore.putBytes(blockId, bytes, level)
+ }
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
}
}
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- marked = true
- myInfo.markReady(size)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
+ val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo)
+ if (putBlockStatus.storageLevel != StorageLevel.NONE) {
+ // Now that the block is in either the memory, tachyon, or disk store,
+ // let other threads read it, and tell the master about it.
+ marked = true
+ putBlockInfo.markReady(size)
+ if (tellMaster) {
+ reportBlockStatus(blockId, putBlockInfo, putBlockStatus)
+ }
+ updatedBlocks += ((blockId, putBlockStatus))
}
} finally {
- // If we failed at putting the block to memory/disk, notify other possible readers
+ // If we failed in putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
- if (! marked) {
+ if (!marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
- myInfo.markFailure()
+ putBlockInfo.markFailure()
logWarning("Putting block " + blockId + " failed")
}
}
@@ -606,8 +738,8 @@ private[spark] class BlockManager(
// values and need to serialize and replicate them now:
if (level.replication > 1) {
data match {
- case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
- case Left(values) => {
+ case ByteBufferValues(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case _ => {
val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
@@ -634,7 +766,7 @@ private[spark] class BlockManager(
Utils.getUsedTimeMs(startTimeMs))
}
- size
+ updatedBlocks
}
/**
@@ -642,7 +774,8 @@ private[spark] class BlockManager(
*/
var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) {
- val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ val tLevel = StorageLevel(
+ level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
}
@@ -671,28 +804,42 @@ private[spark] class BlockManager(
/**
* Write a block consisting of a single object.
*/
- def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+ def putSingle(
+ blockId: BlockId,
+ value: Any,
+ level: StorageLevel,
+ tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = {
put(blockId, Iterator(value), level, tellMaster)
}
/**
* Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
* store reaches its limit and needs to free up space.
+ *
+ * Return the block status if the given block has been updated, else None.
*/
- def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+ def dropFromMemory(
+ blockId: BlockId,
+ data: Either[ArrayBuffer[Any], ByteBuffer]): Option[BlockStatus] = {
+
logInfo("Dropping block " + blockId + " from memory")
val info = blockInfo.get(blockId).orNull
+
+ // If the block has not already been dropped
if (info != null) {
info.synchronized {
// Is this required? As of now, this will be invoked only for blocks which are ready,
// but in case this changes in the future, adding it for consistency's sake.
- if (! info.waitForReady() ) {
+ if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure. Nothing to drop")
- return
+ return None
}
+ var blockIsUpdated = false
val level = info.level
+
+ // Drop to disk, if storage level requires
if (level.useDisk && !diskStore.contains(blockId)) {
logInfo("Writing block " + blockId + " to disk")
data match {
@@ -701,24 +848,33 @@ private[spark] class BlockManager(
case Right(bytes) =>
diskStore.putBytes(blockId, bytes, level)
}
+ blockIsUpdated = true
}
+
+ // Actually drop from memory store
val droppedMemorySize =
if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
- val blockWasRemoved = memoryStore.remove(blockId)
- if (!blockWasRemoved) {
+ val blockIsRemoved = memoryStore.remove(blockId)
+ if (blockIsRemoved) {
+ blockIsUpdated = true
+ } else {
logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
}
+
+ val status = getCurrentBlockStatus(blockId, info)
if (info.tellMaster) {
- reportBlockStatus(blockId, info, droppedMemorySize)
+ reportBlockStatus(blockId, info, status, droppedMemorySize)
}
if (!level.useDisk) {
// The block is completely gone from this node; forget it so we can put() it again later.
blockInfo.remove(blockId)
}
+ if (blockIsUpdated) {
+ return Some(status)
+ }
}
- } else {
- // The block has already been dropped
}
+ None
}
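
With `dropFromMemory` returning `Option[BlockStatus]`, the memory store can accumulate exactly the blocks whose state actually changed. A hedged caller-side sketch (`blockManager`, `blockId`, and `valuesBuffer` assumed in scope):

```scala
import scala.collection.mutable.ArrayBuffer

// Collect only the blocks whose status actually changed; a None means the
// block had already been dropped (or its write had failed) and nothing moved.
val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
blockManager.dropFromMemory(blockId, Left(valuesBuffer)).foreach { status =>
  droppedBlocks += ((blockId, status))
}
```
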
/**
@@ -726,11 +882,22 @@ private[spark] class BlockManager(
* @return The number of blocks removed.
*/
def removeRdd(rddId: Int): Int = {
- // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
- // from RDD.id to blocks.
+ // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
logInfo("Removing RDD " + rddId)
val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
- blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
+ blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
+ blocksToRemove.size
+ }
+
+ /**
+ * Remove all blocks belonging to the given broadcast.
+ */
+ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
+ logInfo("Removing broadcast " + broadcastId)
+ val blocksToRemove = blockInfo.keys.collect {
+ case bid @ BroadcastBlockId(`broadcastId`, _) => bid
+ }
+ blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
blocksToRemove.size
}
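
The backquotes in the pattern above are what make this `collect` correct: a lowercase name in a pattern normally binds a fresh variable, whereas backticks force a comparison against the value already in scope. A minimal illustration (the `"piece0"` field value is made up):

```scala
val broadcastId = 7L
val ids = Seq(BroadcastBlockId(7L, "piece0"), BroadcastBlockId(8L, "piece0"))
ids.collect {
  case bid @ BroadcastBlockId(`broadcastId`, _) => bid  // matches only id 7
}
```
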
@@ -744,13 +911,15 @@ private[spark] class BlockManager(
// Removals are idempotent in disk store and memory store. At worst, we get a warning.
val removedFromMemory = memoryStore.remove(blockId)
val removedFromDisk = diskStore.remove(blockId)
- if (!removedFromMemory && !removedFromDisk) {
+ val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false
+ if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) {
logWarning("Block " + blockId + " could not be removed as it was not found in either " +
- "the disk or memory store")
+ "the disk, memory, or tachyon store")
}
blockInfo.remove(blockId)
if (tellMaster && info.tellMaster) {
- reportBlockStatus(blockId, info)
+ val status = getCurrentBlockStatus(blockId, info)
+ reportBlockStatus(blockId, info, status)
}
} else {
// The block has already been removed; do nothing.
@@ -769,10 +938,10 @@ private[spark] class BlockManager(
}
private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
- val iterator = blockInfo.internalMap.entrySet().iterator()
+ val iterator = blockInfo.getEntrySet.iterator
while (iterator.hasNext) {
val entry = iterator.next()
- val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+ val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
@@ -782,17 +951,21 @@ private[spark] class BlockManager(
if (level.useDisk) {
diskStore.remove(id)
}
+ if (level.useOffHeap) {
+ tachyonStore.remove(id)
+ }
iterator.remove()
logInfo("Dropped block " + id)
}
- reportBlockStatus(id, info)
+ val status = getCurrentBlockStatus(id, info)
+ reportBlockStatus(id, info, status)
}
}
}
def shouldCompress(blockId: BlockId): Boolean = blockId match {
case ShuffleBlockId(_, _, _) => compressShuffle
- case BroadcastBlockId(_) => compressBroadcast
+ case BroadcastBlockId(_, _) => compressBroadcast
case RDDBlockId(_, _) => compressRdds
case TempBlockId(_) => compressShuffleSpill
case _ => false
@@ -818,7 +991,7 @@ private[spark] class BlockManager(
outputStream: OutputStream,
values: Iterator[Any],
serializer: Serializer = defaultSerializer) {
- val byteStream = new FastBufferedOutputStream(outputStream)
+ val byteStream = new BufferedOutputStream(outputStream)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}
@@ -828,10 +1001,9 @@ private[spark] class BlockManager(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
- val byteStream = new FastByteArrayOutputStream(4096)
+ val byteStream = new ByteArrayOutputStream(4096)
dataSerializeStream(blockId, byteStream, values, serializer)
- byteStream.trim()
- ByteBuffer.wrap(byteStream.array)
+ ByteBuffer.wrap(byteStream.toByteArray)
}
/**
@@ -852,10 +1024,15 @@ private[spark] class BlockManager(
heartBeatTask.cancel()
}
connectionManager.stop()
+ shuffleBlockManager.stop()
+ diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
+ if (tachyonInitialized) {
+ tachyonStore.clear()
+ }
metadataCleaner.cancel()
broadcastCleaner.cancel()
logInfo("BlockManager stopped")
@@ -895,9 +1072,8 @@ private[spark] object BlockManager extends Logging {
def blockIdsToBlockManagers(
blockIds: Array[BlockId],
env: SparkEnv,
- blockManagerMaster: BlockManagerMaster = null)
- : Map[BlockId, Seq[BlockManagerId]] =
- {
+ blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[BlockManagerId]] = {
+
// blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster == null) {
@@ -916,18 +1092,14 @@ private[spark] object BlockManager extends Logging {
def blockIdsToExecutorIds(
blockIds: Array[BlockId],
env: SparkEnv,
- blockManagerMaster: BlockManagerMaster = null)
- : Map[BlockId, Seq[String]] =
- {
+ blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
}
def blockIdsToHosts(
blockIds: Array[BlockId],
env: SparkEnv,
- blockManagerMaster: BlockManagerMaster = null)
- : Map[BlockId, Seq[String]] =
- {
+ blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 98cd6e68fa724..be537d77309bc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -50,7 +50,6 @@ private[spark] class BlockManagerId private (
// DEBUG code
Utils.checkHost(host)
assert (port > 0)
-
host + ":" + port
}
@@ -93,7 +92,7 @@ private[spark] class BlockManagerId private (
private[spark] object BlockManagerId {
/**
- * Returns a [[org.apache.spark.storage.BlockManagerId]] for the given configuraiton.
+ * Returns a [[org.apache.spark.storage.BlockManagerId]] for the given configuration.
*
* @param execId ID of the executor.
* @param host Host name of the block manager.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index e531467cccb40..7897fade2df2b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -28,8 +28,7 @@ import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.AkkaUtils
private[spark]
-class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Logging {
-
+class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging {
val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3)
val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000)
@@ -53,8 +52,7 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo
}
/** Register the BlockManager's id with the driver. */
- def registerBlockManager(
- blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
logInfo("Trying to register BlockManager")
tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
logInfo("Registered BlockManager")
@@ -65,9 +63,10 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long): Boolean = {
+ diskSize: Long,
+ tachyonSize: Long): Boolean = {
val res = askDriverWithReply[Boolean](
- UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
+ UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize))
logInfo("Updated info of block " + blockId)
res
}
@@ -82,6 +81,14 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
+ /**
+ * Check if the block manager master has a block. Note that this can be used to check only for
+ * those blocks that have been reported to the block manager master.
+ */
+ def contains(blockId: BlockId) = {
+ !getLocations(blockId).isEmpty
+ }
+
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
@@ -100,12 +107,10 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo
askDriverWithReply(RemoveBlock(blockId))
}
- /**
- * Remove all blocks belonging to the given RDD.
- */
+ /** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
- future onFailure {
+ future.onFailure {
case e: Throwable => logError("Failed to remove RDD " + rddId, e)
}
if (blocking) {
@@ -113,6 +118,31 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo
}
}
+ /** Remove all blocks belonging to the given shuffle. */
+ def removeShuffle(shuffleId: Int, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+ future.onFailure {
+ case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
+ /** Remove all blocks belonging to the given broadcast. */
+ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Int]]](
+ RemoveBroadcast(broadcastId, removeFromMaster))
+ future.onFailure {
+ case e: Throwable =>
+ logError("Failed to remove broadcast " + broadcastId +
+ " with removeFromMaster = " + removeFromMaster, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
/**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
@@ -127,6 +157,51 @@ class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Lo
askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
}
+ /**
+ * Return the block's status on all block managers, if any. NOTE: This is a
+ * potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, this invokes the master to query each block manager for the most
+ * updated block statuses. This is useful when the master is not informed of the given block
+ * by all block managers.
+ */
+ def getBlockStatus(
+ blockId: BlockId,
+ askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
+ val msg = GetBlockStatus(blockId, askSlaves)
+ /*
+ * To avoid potential deadlocks, the use of Futures is necessary: the master actor
+ * should not block waiting on a block manager, which can in turn be waiting on the
+ * master actor for a response to a prior message.
+ */
+ val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
+ val (blockManagerIds, futures) = response.unzip
+ val result = Await.result(Future.sequence(futures), timeout)
+ if (result == null) {
+ throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
+ }
+ val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]]
+ blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) =>
+ status.map { s => (blockManagerId, s) }
+ }.toMap
+ }
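
The gather step above relies on `Future.sequence` to turn one future per block manager into a single future of all answers, blocking only once at the end. A self-contained sketch of the same pattern with placeholder values:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// One Future per (hypothetical) block manager; sequence them, then block once.
val perManager: Seq[Future[Option[Int]]] =
  Seq(Future(Some(1)), Future(None), Future(Some(3)))
val combined: Future[Seq[Option[Int]]] = Future.sequence(perManager)
val found = Await.result(combined, 10.seconds).flatten  // Seq(1, 3)
```
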
+
+ /**
+ * Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This
+ * is a potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, this invokes the master to query each block manager for the most
+ * updated block statuses. This is useful when the master is not informed of the given block
+ * by all block managers.
+ */
+ def getMatchingBlockIds(
+ filter: BlockId => Boolean,
+ askSlaves: Boolean): Seq[BlockId] = {
+ val msg = GetMatchingBlockIds(filter, askSlaves)
+ val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
+ Await.result(future, timeout)
+ }
+
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (driverActor != null) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index a999d76a326a6..63fa5d3eb6541 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -28,6 +28,7 @@ import akka.actor.{Actor, ActorRef, Cancellable}
import akka.pattern.ask
import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.scheduler._
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -36,11 +37,11 @@ import org.apache.spark.util.{AkkaUtils, Utils}
* all slaves' block managers.
*/
private[spark]
-class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Actor with Logging {
+class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus)
+ extends Actor with Logging {
// Mapping from block manager id to the block manager's information.
- private val blockManagerInfo =
- new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
+ private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
// Mapping from executor ID to block manager ID.
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
@@ -72,10 +73,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
register(blockManagerId, maxMemSize, slaveActor)
sender ! true
- case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ case UpdateBlockInfo(
+ blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) =>
// TODO: Ideally we want to handle all the message replies in receive instead of in the
// individual private methods.
- updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)
+ updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)
case GetLocations(blockId) =>
sender ! getLocations(blockId)
@@ -92,9 +94,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
case GetStorageStatus =>
sender ! storageStatus
+ case GetBlockStatus(blockId, askSlaves) =>
+ sender ! blockStatus(blockId, askSlaves)
+
+ case GetMatchingBlockIds(filter, askSlaves) =>
+ sender ! getMatchingBlockIds(filter, askSlaves)
+
case RemoveRdd(rddId) =>
sender ! removeRdd(rddId)
+ case RemoveShuffle(shuffleId) =>
+ sender ! removeShuffle(shuffleId)
+
+ case RemoveBroadcast(broadcastId, removeFromDriver) =>
+ sender ! removeBroadcast(broadcastId, removeFromDriver)
+
case RemoveBlock(blockId) =>
removeBlockFromWorkers(blockId)
sender ! true
@@ -138,9 +152,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
// The dispatcher is used as an implicit argument into the Future sequence construction.
import context.dispatcher
val removeMsg = RemoveRdd(rddId)
- Future.sequence(blockManagerInfo.values.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
- }.toSeq)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq
+ )
+ }
+
+ private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
+ // Nothing to do in the BlockManagerMasterActor data structures
+ import context.dispatcher
+ val removeMsg = RemoveShuffle(shuffleId)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
+ }.toSeq
+ )
+ }
+
+ /**
+ * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
+ * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
+ * from the executors, but not from the driver.
+ */
+ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
+ // TODO: Consolidate usages of
+ import context.dispatcher
+ val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
+ val requiredBlockManagers = blockManagerInfo.values.filter { info =>
+ removeFromDriver || info.blockManagerId.executorId != ""
+ }
+ Future.sequence(
+ requiredBlockManagers.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq
+ )
}
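
All three `remove*` helpers use the same Akka idiom: `ask` returns an untyped `Future[Any]`, and `mapTo` narrows it, failing the future if a slave replies with something of another type. A sketch under the assumption of an `ActorRef slaveActor` and a `Timeout akkaTimeout` in scope:

```scala
import akka.pattern.ask
import scala.concurrent.Future

// ask(...) yields Future[Any]; mapTo[Int] fails the future (with a
// ClassCastException) if the slave's reply is not an Int.
val removed: Future[Int] =
  slaveActor.ask(RemoveBroadcast(broadcastId, removeFromDriver))(akkaTimeout).mapTo[Int]
```
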
private def removeBlockManager(blockManagerId: BlockManagerId) {
@@ -157,9 +203,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
val locations = blockLocations.get(blockId)
locations -= blockManagerId
if (locations.size == 0) {
- blockLocations.remove(locations)
+ blockLocations.remove(blockId)
}
}
+ listenerBus.post(SparkListenerBlockManagerRemoved(blockManagerId))
}
private def expireDeadHosts() {
@@ -217,11 +264,66 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
private def storageStatus: Array[StorageStatus] = {
blockManagerInfo.map { case(blockManagerId, info) =>
- import collection.JavaConverters._
- StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
+ val blockMap = mutable.Map[BlockId, BlockStatus](info.blocks.toSeq: _*)
+ new StorageStatus(blockManagerId, info.maxMem, blockMap)
}.toArray
}
+ /**
+ * Return the block's status for all block managers, if any. NOTE: This is a
+ * potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def blockStatus(
+ blockId: BlockId,
+ askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
+ import context.dispatcher
+ val getBlockStatus = GetBlockStatus(blockId)
+ /*
+ * Rather than blocking on the block status query, the master actor should simply return
+ * Futures to avoid potential deadlocks. A deadlock can arise if there exists a block manager
+ * that is also waiting for this master actor's response to a previous message.
+ */
+ blockManagerInfo.values.map { info =>
+ val blockStatusFuture =
+ if (askSlaves) {
+ info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
+ } else {
+ Future { info.getStatus(blockId) }
+ }
+ (info.blockManagerId, blockStatusFuture)
+ }.toMap
+ }
+
+ /**
+ * Return the ids of blocks present in all the block managers that match the given filter.
+ * NOTE: This is a potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def getMatchingBlockIds(
+ filter: BlockId => Boolean,
+ askSlaves: Boolean): Future[Seq[BlockId]] = {
+ import context.dispatcher
+ val getMatchingBlockIds = GetMatchingBlockIds(filter)
+ Future.sequence(
+ blockManagerInfo.values.map { info =>
+ val future =
+ if (askSlaves) {
+ info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]]
+ } else {
+ Future { info.blocks.keys.filter(filter).toSeq }
+ }
+ future
+ }
+ ).map(_.flatten.toSeq)
+ }
+
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -233,9 +335,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
case None =>
blockManagerIdByExecutor(id.executorId) = id
}
- blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
- id, System.currentTimeMillis(), maxMemSize, slaveActor)
+ blockManagerInfo(id) =
+ new BlockManagerInfo(id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
+ listenerBus.post(SparkListenerBlockManagerAdded(id, maxMemSize))
}
private def updateBlockInfo(
@@ -243,7 +346,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long) {
+ diskSize: Long,
+ tachyonSize: Long) {
if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.executorId == "" && !isLocal) {
@@ -262,7 +366,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
return
}
- blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
+ blockManagerInfo(blockManagerId).updateBlockInfo(
+ blockId, storageLevel, memSize, diskSize, tachyonSize)
var locations: mutable.HashSet[BlockManagerId] = null
if (blockLocations.containsKey(blockId)) {
@@ -306,98 +411,113 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act
}
}
+private[spark] case class BlockStatus(
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ tachyonSize: Long)
-private[spark]
-object BlockManagerMasterActor {
-
- case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+private[spark] class BlockManagerInfo(
+ val blockManagerId: BlockManagerId,
+ timeMs: Long,
+ val maxMem: Long,
+ val slaveActor: ActorRef)
+ extends Logging {
- class BlockManagerInfo(
- val blockManagerId: BlockManagerId,
- timeMs: Long,
- val maxMem: Long,
- val slaveActor: ActorRef)
- extends Logging {
+ private var _lastSeenMs: Long = timeMs
+ private var _remainingMem: Long = maxMem
- private var _lastSeenMs: Long = timeMs
- private var _remainingMem: Long = maxMem
+ // Mapping from block id to its status.
+ private val _blocks = new JHashMap[BlockId, BlockStatus]
- // Mapping from block id to its status.
- private val _blocks = new JHashMap[BlockId, BlockStatus]
+ logInfo("Registering block manager %s with %s RAM".format(
+ blockManagerId.hostPort, Utils.bytesToString(maxMem)))
- logInfo("Registering block manager %s with %s RAM".format(
- blockManagerId.hostPort, Utils.bytesToString(maxMem)))
+ def getStatus(blockId: BlockId) = Option(_blocks.get(blockId))
- def updateLastSeenMs() {
- _lastSeenMs = System.currentTimeMillis()
- }
+ def updateLastSeenMs() {
+ _lastSeenMs = System.currentTimeMillis()
+ }
- def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long,
- diskSize: Long) {
+ def updateBlockInfo(
+ blockId: BlockId,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ tachyonSize: Long) {
- updateLastSeenMs()
+ updateLastSeenMs()
- if (_blocks.containsKey(blockId)) {
- // The block exists on the slave already.
- val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel
+ if (_blocks.containsKey(blockId)) {
+ // The block exists on the slave already.
+ val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel
- if (originalLevel.useMemory) {
- _remainingMem += memSize
- }
+ if (originalLevel.useMemory) {
+ _remainingMem += memSize
}
+ }
- if (storageLevel.isValid) {
- // isValid means it is either stored in-memory or on-disk.
- // But the memSize here indicates the data size in or dropped from memory,
- // and the diskSize here indicates the data size in or dropped to disk.
- // They can be both larger than 0, when a block is dropped from memory to disk.
- // Therefore, a safe way to set BlockStatus is to set its info in accurate modes.
- if (storageLevel.useMemory) {
- _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0))
- _remainingMem -= memSize
- logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
- Utils.bytesToString(_remainingMem)))
- }
- if (storageLevel.useDisk) {
- _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize))
- logInfo("Added %s on disk on %s (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
- }
- } else if (_blocks.containsKey(blockId)) {
- // If isValid is not true, drop the block.
- val blockStatus: BlockStatus = _blocks.get(blockId)
- _blocks.remove(blockId)
- if (blockStatus.storageLevel.useMemory) {
- _remainingMem += blockStatus.memSize
- logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
- Utils.bytesToString(_remainingMem)))
- }
- if (blockStatus.storageLevel.useDisk) {
- logInfo("Removed %s on %s on disk (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize)))
- }
+ if (storageLevel.isValid) {
+ /* isValid means it is either stored in memory, on disk, or in Tachyon.
+ * The memSize here indicates the data size in or dropped from memory,
+ * tachyonSize indicates the data size in or dropped from Tachyon,
+ * and diskSize indicates the data size in or dropped to disk.
+ * More than one of them can be larger than 0, e.g. when a block is dropped
+ * from memory to disk. Therefore, a safe way to set BlockStatus is to record
+ * each mode's size separately and accurately. */
+ if (storageLevel.useMemory) {
+ _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0))
+ _remainingMem -= memSize
+ logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
+ Utils.bytesToString(_remainingMem)))
+ }
+ if (storageLevel.useDisk) {
+ _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0))
+ logInfo("Added %s on disk on %s (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
+ }
+ if (storageLevel.useOffHeap) {
+ _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize))
+ logInfo("Added %s on tachyon on %s (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize)))
+ }
+ } else if (_blocks.containsKey(blockId)) {
+ // If isValid is not true, drop the block.
+ val blockStatus: BlockStatus = _blocks.get(blockId)
+ _blocks.remove(blockId)
+ if (blockStatus.storageLevel.useMemory) {
+ _remainingMem += blockStatus.memSize
+ logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
+ Utils.bytesToString(_remainingMem)))
+ }
+ if (blockStatus.storageLevel.useDisk) {
+ logInfo("Removed %s on %s on disk (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize)))
+ }
+ if (blockStatus.storageLevel.useOffHeap) {
+ logInfo("Removed %s on %s on tachyon (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize)))
}
}
+ }
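
A worked illustration of the comment above, with assumed values: when a block is dropped from memory to disk, `memSize` carries the amount freed from memory (credited back to `_remainingMem`) while `diskSize` carries the new on-disk footprint, and only the disk entry is recorded because the new level is disk-only:

```scala
// Hypothetical call (info: BlockManagerInfo assumed): block rdd_1_0 was
// dropped from memory (1 KB freed) and now occupies 2 KB on disk.
info.updateBlockInfo(
  RDDBlockId(1, 0),
  StorageLevel.DISK_ONLY,
  memSize = 1024L,      // credited back to _remainingMem
  diskSize = 2048L,     // recorded as BlockStatus(DISK_ONLY, 0, 2048, 0)
  tachyonSize = 0L)
```
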
- def removeBlock(blockId: BlockId) {
- if (_blocks.containsKey(blockId)) {
- _remainingMem += _blocks.get(blockId).memSize
- _blocks.remove(blockId)
- }
+ def removeBlock(blockId: BlockId) {
+ if (_blocks.containsKey(blockId)) {
+ _remainingMem += _blocks.get(blockId).memSize
+ _blocks.remove(blockId)
}
+ }
- def remainingMem: Long = _remainingMem
+ def remainingMem: Long = _remainingMem
- def lastSeenMs: Long = _lastSeenMs
+ def lastSeenMs: Long = _lastSeenMs
- def blocks: JHashMap[BlockId, BlockStatus] = _blocks
+ def blocks: JHashMap[BlockId, BlockStatus] = _blocks
- override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
+ override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
- def clear() {
- _blocks.clear()
- }
+ def clear() {
+ _blocks.clear()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index bbb9529b5a0ca..2b53bf33b5fba 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -34,6 +34,13 @@ private[storage] object BlockManagerMessages {
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
+ // Remove all blocks belonging to a specific shuffle.
+ case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave
+
+ // Remove all blocks belonging to a specific broadcast.
+ case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
+ extends ToBlockManagerSlave
+
//////////////////////////////////////////////////////////////////////////////////
// Messages from slaves to the master.
@@ -53,11 +60,12 @@ private[storage] object BlockManagerMessages {
var blockId: BlockId,
var storageLevel: StorageLevel,
var memSize: Long,
- var diskSize: Long)
+ var diskSize: Long,
+ var tachyonSize: Long)
extends ToBlockManagerMaster
with Externalizable {
- def this() = this(null, null, null, 0, 0) // For deserialization only
+ def this() = this(null, null, null, 0, 0, 0) // For deserialization only
override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out)
@@ -65,6 +73,7 @@ private[storage] object BlockManagerMessages {
storageLevel.writeExternal(out)
out.writeLong(memSize)
out.writeLong(diskSize)
+ out.writeLong(tachyonSize)
}
override def readExternal(in: ObjectInput) {
@@ -73,21 +82,25 @@ private[storage] object BlockManagerMessages {
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()
+ tachyonSize = in.readLong()
}
}
object UpdateBlockInfo {
- def apply(blockManagerId: BlockManagerId,
+ def apply(
+ blockManagerId: BlockManagerId,
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long): UpdateBlockInfo = {
- new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)
+ diskSize: Long,
+ tachyonSize: Long): UpdateBlockInfo = {
+ new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)
}
// For pattern-matching
- def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = {
- Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
+ def unapply(h: UpdateBlockInfo)
+ : Option[(BlockManagerId, BlockId, StorageLevel, Long, Long, Long)] = {
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize, h.tachyonSize))
}
}
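
Because `UpdateBlockInfo` is a hand-rolled `Externalizable` class rather than a case class, the explicit `unapply` above is what lets the master actor destructure it in `receive`. A sketch of a match against an assumed `msg: UpdateBlockInfo`:

```scala
// The extractor yields all six fields, including the new tachyonSize.
msg match {
  case UpdateBlockInfo(bmId, blockId, level, memSize, diskSize, tachyonSize) =>
    println(s"$blockId on $bmId: mem=$memSize disk=$diskSize tachyon=$tachyonSize")
}
```
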
@@ -103,7 +116,13 @@ private[storage] object BlockManagerMessages {
case object GetMemoryStatus extends ToBlockManagerMaster
- case object ExpireDeadHosts extends ToBlockManagerMaster
-
case object GetStorageStatus extends ToBlockManagerMaster
+
+ case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true)
+ extends ToBlockManagerMaster
+
+ case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true)
+ extends ToBlockManagerMaster
+
+ case object ExpireDeadHosts extends ToBlockManagerMaster
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index bcfb82d3c7336..6d4db064dff58 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -17,8 +17,11 @@
package org.apache.spark.storage
-import akka.actor.Actor
+import scala.concurrent.Future
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark.{Logging, MapOutputTracker}
import org.apache.spark.storage.BlockManagerMessages._
/**
@@ -26,14 +29,59 @@ import org.apache.spark.storage.BlockManagerMessages._
* this is used to remove blocks from the slave's BlockManager.
*/
private[storage]
-class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
- override def receive = {
+class BlockManagerSlaveActor(
+ blockManager: BlockManager,
+ mapOutputTracker: MapOutputTracker)
+ extends Actor with Logging {
+
+ import context.dispatcher
+ // Operations that involve removing blocks may be slow and should be done asynchronously
+ override def receive = {
case RemoveBlock(blockId) =>
- blockManager.removeBlock(blockId)
+ doAsync[Boolean]("removing block " + blockId, sender) {
+ blockManager.removeBlock(blockId)
+ true
+ }
case RemoveRdd(rddId) =>
- val numBlocksRemoved = blockManager.removeRdd(rddId)
- sender ! numBlocksRemoved
+ doAsync[Int]("removing RDD " + rddId, sender) {
+ blockManager.removeRdd(rddId)
+ }
+
+ case RemoveShuffle(shuffleId) =>
+ doAsync[Boolean]("removing shuffle " + shuffleId, sender) {
+ if (mapOutputTracker != null) {
+ mapOutputTracker.unregisterShuffle(shuffleId)
+ }
+ blockManager.shuffleBlockManager.removeShuffle(shuffleId)
+ }
+
+ case RemoveBroadcast(broadcastId, tellMaster) =>
+ doAsync[Int]("removing broadcast " + broadcastId, sender) {
+ blockManager.removeBroadcast(broadcastId, tellMaster)
+ }
+
+ case GetBlockStatus(blockId, _) =>
+ sender ! blockManager.getStatus(blockId)
+
+ case GetMatchingBlockIds(filter, _) =>
+ sender ! blockManager.getMatchingBlockIds(filter)
+ }
+
+ private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
+ val future = Future {
+ logDebug(actionMessage)
+ body
+ }
+ future.onSuccess { case response =>
+ logDebug("Done " + actionMessage + ", response is " + response)
+ responseActor ! response
+ logDebug("Sent response: " + response + " to " + responseActor)
+ }
+ future.onFailure { case t: Throwable =>
+ logError("Error in " + actionMessage, t)
+ responseActor ! null.asInstanceOf[T]
+ }
}
}
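
A detail worth noting in `doAsync`: the sender is passed in as `responseActor` before the `Future` body runs, because `sender` is only stable inside `receive` and would be unreliable from a callback thread. A generic, self-contained sketch of the same pattern:

```scala
import akka.actor.{Actor, ActorRef}
import scala.concurrent.Future

class SlowWorker extends Actor {
  import context.dispatcher
  def receive = {
    case work: String =>
      val responseActor: ActorRef = sender  // capture eagerly; see note above
      Future { work.length }.onSuccess { case n => responseActor ! n }
  }
}
```
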
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
index 7168ae18c2615..a2bfce7b4a0fa 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -37,7 +37,7 @@ private[spark] class BlockMessage() {
private var id: BlockId = null
private var data: ByteBuffer = null
private var level: StorageLevel = null
-
+
def set(getBlock: GetBlock) {
typ = BlockMessage.TYPE_GET_BLOCK
id = getBlock.id
@@ -57,7 +57,6 @@ private[spark] class BlockMessage() {
}
def set(buffer: ByteBuffer) {
- val startTime = System.currentTimeMillis
/*
println()
println("BlockMessage: ")
@@ -75,13 +74,13 @@ private[spark] class BlockMessage() {
idBuilder += buffer.getChar()
}
id = BlockId(idBuilder.toString)
-
+
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
val booleanInt = buffer.getInt()
val replication = buffer.getInt()
level = StorageLevel(booleanInt, replication)
-
+
val dataLength = buffer.getInt()
data = ByteBuffer.allocate(dataLength)
if (dataLength != buffer.remaining) {
@@ -100,7 +99,6 @@ private[spark] class BlockMessage() {
data.flip()
}
- val finishTime = System.currentTimeMillis
}
def set(bufferMsg: BufferMessage) {
@@ -108,14 +106,13 @@ private[spark] class BlockMessage() {
buffer.clear()
set(buffer)
}
-
+
def getType: Int = typ
def getId: BlockId = id
def getData: ByteBuffer = data
def getLevel: StorageLevel = level
-
+
def toBufferMessage: BufferMessage = {
- val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]()
var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
buffer.putInt(typ).putInt(id.name.length)
@@ -127,7 +124,7 @@ private[spark] class BlockMessage() {
buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication)
buffer.flip()
buffers += buffer
-
+
buffer = ByteBuffer.allocate(4).putInt(data.remaining)
buffer.flip()
buffers += buffer
@@ -140,7 +137,7 @@ private[spark] class BlockMessage() {
buffers += data
}
-
+
/*
println()
println("BlockMessage: ")
@@ -153,12 +150,11 @@ private[spark] class BlockMessage() {
println()
println()
*/
- val finishTime = System.currentTimeMillis
Message.createBufferMessage(buffers)
}
override def toString: String = {
- "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
+ "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
", data = " + (if (data != null) data.remaining.toString else "null") + "]"
}
}
@@ -168,7 +164,7 @@ private[spark] object BlockMessage {
val TYPE_GET_BLOCK: Int = 1
val TYPE_GOT_BLOCK: Int = 2
val TYPE_PUT_BLOCK: Int = 3
-
+
def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
val newBlockMessage = new BlockMessage()
newBlockMessage.set(bufferMessage)
@@ -192,7 +188,7 @@ private[spark] object BlockMessage {
newBlockMessage.set(gotBlock)
newBlockMessage
}
-
+
def fromPutBlock(putBlock: PutBlock): BlockMessage = {
val newBlockMessage = new BlockMessage()
newBlockMessage.set(putBlock)
@@ -206,7 +202,7 @@ private[spark] object BlockMessage {
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)
-
+
println(B.getId + " " + B.getLevel)
println(C.getId + " " + C.getLevel)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index dc62b1efaa7d4..973d85c0a9b3a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -27,16 +27,16 @@ import org.apache.spark.network._
private[spark]
class BlockMessageArray(var blockMessages: Seq[BlockMessage])
extends Seq[BlockMessage] with Logging {
-
+
def this(bm: BlockMessage) = this(Array(bm))
def this() = this(null.asInstanceOf[Seq[BlockMessage]])
- def apply(i: Int) = blockMessages(i)
+ def apply(i: Int) = blockMessages(i)
def iterator = blockMessages.iterator
- def length = blockMessages.length
+ def length = blockMessages.length
def set(bufferMessage: BufferMessage) {
val startTime = System.currentTimeMillis
@@ -62,15 +62,15 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
logDebug("Trying to convert buffer " + newBuffer + " to block message")
val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
logDebug("Created " + newBlockMessage)
- newBlockMessages += newBlockMessage
+ newBlockMessages += newBlockMessage
buffer.position(buffer.position() + size)
}
val finishTime = System.currentTimeMillis
logDebug("Converted block message array from buffer message in " +
(finishTime - startTime) / 1000.0 + " s")
- this.blockMessages = newBlockMessages
+ this.blockMessages = newBlockMessages
}
-
+
def toBufferMessage: BufferMessage = {
val buffers = new ArrayBuffer[ByteBuffer]()
@@ -83,7 +83,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
buffers ++= bufferMessage.buffers
logDebug("Added " + bufferMessage)
})
-
+
logDebug("Buffer list:")
buffers.foreach((x: ByteBuffer) => logDebug("" + x))
/*
@@ -103,13 +103,13 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage])
}
private[spark] object BlockMessageArray {
-
+
def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
val newBlockMessageArray = new BlockMessageArray()
newBlockMessageArray.set(bufferMessage)
newBlockMessageArray
}
-
+
def main(args: Array[String]) {
val blockMessages =
(0 until 10).map { i =>
@@ -124,10 +124,10 @@ private[spark] object BlockMessageArray {
}
val blockMessageArray = new BlockMessageArray(blockMessages)
println("Block message array created")
-
+
val bufferMessage = blockMessageArray.toBufferMessage
println("Converted to buffer message")
-
+
val totalSize = bufferMessage.size
val newBuffer = ByteBuffer.allocate(totalSize)
newBuffer.clear()
@@ -137,7 +137,7 @@ private[spark] object BlockMessageArray {
buffer.rewind()
})
newBuffer.flip
- val newBufferMessage = Message.createBufferMessage(newBuffer)
+ val newBufferMessage = Message.createBufferMessage(newBuffer)
println("Copied to new buffer message, size = " + newBufferMessage.size)
val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
@@ -147,7 +147,7 @@ private[spark] object BlockMessageArray {
case BlockMessage.TYPE_PUT_BLOCK => {
val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
println(pB)
- }
+ }
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
println(gB)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 696b930a26b9e..a2687e6be4e34 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,11 +17,9 @@
package org.apache.spark.storage
-import java.io.{FileOutputStream, File, OutputStream}
+import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
import org.apache.spark.Logging
import org.apache.spark.serializer.{SerializationStream, Serializer}
@@ -119,7 +117,7 @@ private[spark] class DiskBlockObjectWriter(
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
lastValidPosition = initialPosition
- bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ bs = compressStream(new BufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index b047644b88f48..9a9be047c7245 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -28,7 +28,7 @@ import org.apache.spark.Logging
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -37,6 +37,9 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
+ def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel,
+ returnValues: Boolean): PutResult
+
def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f3e1c38744d78..cf6ef0029a861 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -90,6 +90,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
def getFile(blockId: BlockId): File = getFile(blockId.name)
+ /** Check if disk block manager has a block. */
+ def containsBlock(blockId: BlockId): Boolean = {
+ getBlockLocation(blockId).file.exists()
+ }
+
+ /** List all the blocks currently stored on disk by the disk manager. */
+ def getAllBlocks(): Seq[BlockId] = {
+ // Get all the files inside the arrays of directories (subDirs is an array of arrays)
+ subDirs.flatten.filter(_ != null).flatMap { dir =>
+ val files = dir.list()
+ if (files != null) files else Seq.empty
+ }.map(BlockId.apply)
+ }
+
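
`getAllBlocks` leans on the convention that a block's file is named exactly after its `BlockId`, so `BlockId.apply` can parse directory listings back into ids. A small illustration of that round-trip:

```scala
// "rdd_2_3" is the on-disk file name for partition 3 of RDD 2.
val id: BlockId = BlockId("rdd_2_3")
assert(id == RDDBlockId(2, 3))
assert(id.name == "rdd_2_3")  // name reproduces the file name
```
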
/** Produces a unique block id and File suitable for intermediate results. */
def createTempBlock(): (TempBlockId, File) = {
var blockId = new TempBlockId(UUID.randomUUID())
@@ -136,20 +150,27 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
+ DiskBlockManager.this.stop()
+ }
+ })
+ }
- if (shuffleSender != null) {
- shuffleSender.stop()
+ /** Cleanup local dirs and stop shuffle sender. */
+ private[spark] def stop() {
+ localDirs.foreach { localDir =>
+ if (localDir.isDirectory() && localDir.exists()) {
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
}
}
- })
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
}
private[storage] def startShuffleBlockSender(port: Int): Int = {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 74d7b5b82f357..0ab9fad422717 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -39,7 +39,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
diskManager.getBlockLocation(blockId).length
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
@@ -54,6 +54,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
}
override def putValues(
@@ -61,13 +62,22 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
+ : PutResult = {
+ putValues(blockId, values.toIterator, level, returnValues)
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
: PutResult = {
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
- blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ blockManager.dataSerializeStream(blockId, outputStream, values)
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
index 555486830a769..132502b75f8cd 100644
--- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -23,6 +23,6 @@ import java.io.File
* References a particular segment of a file (potentially the entire file),
* based off an offset and a length.
*/
-private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
+private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 18141756518c5..488f1ea9628f5 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -49,7 +49,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -59,8 +59,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
tryToPut(blockId, elements, sizeEstimate, true)
+ PutResult(sizeEstimate, Left(values.toIterator))
} else {
tryToPut(blockId, bytes, bytes.limit, false)
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
}
}
@@ -68,20 +70,28 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
- returnValues: Boolean)
- : PutResult = {
-
+ returnValues: Boolean): PutResult = {
if (level.deserialized) {
val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
- tryToPut(blockId, values, sizeEstimate, true)
- PutResult(sizeEstimate, Left(values.iterator))
+ val putAttempt = tryToPut(blockId, values, sizeEstimate, deserialized = true)
+ PutResult(sizeEstimate, Left(values.iterator), putAttempt.droppedBlocks)
} else {
val bytes = blockManager.dataSerialize(blockId, values.iterator)
- tryToPut(blockId, bytes, bytes.limit, false)
- PutResult(bytes.limit(), Right(bytes.duplicate()))
+ val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false)
+ PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
}
}
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean): PutResult = {
+ val valueEntries = new ArrayBuffer[Any]()
+ valueEntries ++= values
+ putValues(blockId, valueEntries, level, returnValues)
+ }
+
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val entry = entries.synchronized {
entries.get(blockId)
@@ -143,19 +153,34 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated)
* size must also be passed by the caller.
*
- * Locks on the object putLock to ensure that all the put requests and its associated block
+ * Lock on the object putLock to ensure that all the put requests and their associated block
* dropping are done by only one thread at a time. Otherwise while one thread is dropping
* blocks to free memory for one block, another thread may use up the freed space for
* another block.
+ *
+ * Return whether the put was successful, along with the blocks dropped in the process.
*/
- private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = {
- // TODO: Its possible to optimize the locking by locking entries only when selecting blocks
- // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been
- // released, it must be ensured that those to-be-dropped blocks are not double counted for
- // freeing up more space for another block that needs to be put. Only then the actually dropping
- // of blocks (and writing to disk if necessary) can proceed in parallel.
+ private def tryToPut(
+ blockId: BlockId,
+ value: Any,
+ size: Long,
+ deserialized: Boolean): ResultWithDroppedBlocks = {
+
+ /* TODO: It's possible to optimize the locking by locking entries only when selecting blocks
+ * to be dropped. Once the to-be-dropped blocks have been selected, and the lock on entries has
+ * been released, it must be ensured that those to-be-dropped blocks are not double counted
+ * for freeing up more space for another block that needs to be put. Only then can the actual
+ * dropping of blocks (and writing to disk if necessary) proceed in parallel. */
+
+ var putSuccess = false
+ val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+
putLock.synchronized {
- if (ensureFreeSpace(blockId, size)) {
+ val freeSpaceResult = ensureFreeSpace(blockId, size)
+ val enoughFreeSpace = freeSpaceResult.success
+ droppedBlocks ++= freeSpaceResult.droppedBlocks
+
+ if (enoughFreeSpace) {
val entry = new Entry(value, size, deserialized)
entries.synchronized {
entries.put(blockId, entry)
@@ -168,7 +193,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
logInfo("Block %s stored as bytes to memory (size %s, free %s)".format(
blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory)))
}
- true
+ putSuccess = true
} else {
// Tell the block manager that we couldn't put it in memory so that it can drop it to
// disk if the block allows disk storage.
@@ -177,29 +202,33 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} else {
Right(value.asInstanceOf[ByteBuffer].duplicate())
}
- blockManager.dropFromMemory(blockId, data)
- false
+ val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
+ droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
}
}
+ ResultWithDroppedBlocks(putSuccess, droppedBlocks)
}
/**
- * Tries to free up a given amount of space to store a particular block, but can fail and return
- * false if either the block is bigger than our memory or it would require replacing another
- * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
+ * Try to free up a given amount of space to store a particular block. This can fail if
+ * either the block is bigger than our memory or it would require replacing another block
+ * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
* don't fit into memory that we want to avoid).
*
- * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks.
+ * Assume that a lock is held by the caller to ensure only one thread is dropping blocks.
* Otherwise, the freed space may fill up before the caller puts in their new value.
+ *
+ * Return whether there is enough free space, along with the blocks dropped in the process.
*/
- private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = {
-
+ private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): ResultWithDroppedBlocks = {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory))
+ val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+
if (space > maxMemory) {
logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit")
- return false
+ return ResultWithDroppedBlocks(success = false, droppedBlocks)
}
if (maxMemory - currentMemory < space) {
@@ -215,13 +244,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
val pair = iterator.next()
val blockId = pair.getKey
- if (rddToAdd.isDefined && rddToAdd == getRddId(blockId)) {
- logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
- "block from the same RDD")
- return false
+ if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) {
+ selectedBlocks += blockId
+ selectedMemory += pair.getValue.size
}
- selectedBlocks += blockId
- selectedMemory += pair.getValue.size
}
}
@@ -238,15 +264,18 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} else {
Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
}
- blockManager.dropFromMemory(blockId, data)
+ val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
+ droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
}
}
- return true
+ return ResultWithDroppedBlocks(success = true, droppedBlocks)
} else {
- return false
+ logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " +
+ "from the same RDD")
+ return ResultWithDroppedBlocks(success = false, droppedBlocks)
}
}
- true
+ ResultWithDroppedBlocks(success = true, droppedBlocks)
}
override def contains(blockId: BlockId): Boolean = {
@@ -254,3 +283,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
+private case class ResultWithDroppedBlocks(
+ success: Boolean,
+ droppedBlocks: Seq[(BlockId, BlockStatus)])
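
To illustrate how the new dropped-block plumbing composes, here is a sketch of folding several
ResultWithDroppedBlocks values upward, mirroring how MemoryStore threads putAttempt.droppedBlocks
into the PutResult it returns (recordPuts is hypothetical, not part of the patch):

    // Illustrative only: count successful puts and collect every block dropped along the way.
    def recordPuts(results: Seq[ResultWithDroppedBlocks]): (Int, Seq[(BlockId, BlockStatus)]) = {
      val successes = results.count(_.success)
      val allDropped = results.flatMap(_.droppedBlocks)  // preserves drop order
      (successes, allDropped)
    }
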
diff --git a/core/src/main/scala/org/apache/spark/storage/PutResult.scala b/core/src/main/scala/org/apache/spark/storage/PutResult.scala
index 2eba2f06b5bfd..f0eac7594ecf6 100644
--- a/core/src/main/scala/org/apache/spark/storage/PutResult.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PutResult.scala
@@ -20,7 +20,13 @@ package org.apache.spark.storage
import java.nio.ByteBuffer
/**
- * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the
- * values put if the caller asked for them to be returned (e.g. for chaining replication)
+ * Result of adding a block into a BlockStore. This case class contains a few things:
+ * (1) The estimated size of the put,
+ * (2) The values put if the caller asked for them to be returned (e.g. for chaining
+ * replication), and
+ * (3) A list of blocks dropped as a result of this put. This is always empty for DiskStore.
*/
-private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer])
+private[spark] case class PutResult(
+ size: Long,
+ data: Either[Iterator[_], ByteBuffer],
+ droppedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty)
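
A caller that asked for the values back can branch on the Either in data and then report any
collateral drops. A sketch of that consumption pattern (handleValue and handleBytes are
placeholder sinks defined here only for illustration):

    def handleValue(v: Any): Unit = println(v)                       // placeholder sink
    def handleBytes(b: java.nio.ByteBuffer): Unit = println(b.remaining + " bytes")

    def consume(result: PutResult): Unit = {
      result.data match {
        case Left(iterator)    => iterator.foreach(handleValue)      // deserialized values
        case Right(byteBuffer) => handleBytes(byteBuffer)            // serialized bytes
      }
      result.droppedBlocks.foreach { case (blockId, status) =>
        println(s"dropped $blockId, now at storage level ${status.storageLevel}")
      }
    }
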
diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
new file mode 100644
index 0000000000000..023fd6e4d8baa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+@DeveloperApi
+class RDDInfo(
+ val id: Int,
+ val name: String,
+ val numPartitions: Int,
+ val storageLevel: StorageLevel)
+ extends Ordered[RDDInfo] {
+
+ var numCachedPartitions = 0
+ var memSize = 0L
+ var diskSize = 0L
+ var tachyonSize = 0L
+
+ override def toString = {
+ import Utils.bytesToString
+ ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " +
+ "TachyonSize: %s; DiskSize: %s").format(
+ name, id, storageLevel.toString, numCachedPartitions, numPartitions,
+ bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize))
+ }
+
+ override def compare(that: RDDInfo) = {
+ this.id - that.id
+ }
+}
+
+private[spark] object RDDInfo {
+ def fromRdd(rdd: RDD[_]): RDDInfo = {
+ val rddName = Option(rdd.name).getOrElse(rdd.id.toString)
+ new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel)
+ }
+}
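
Because RDDInfo extends Ordered[RDDInfo] and compares by id, collections of them sort with the
standard library directly. A small sketch using only the types introduced above:

    val infos = Seq(
      new RDDInfo(2, "pairs", 4, StorageLevel.MEMORY_ONLY),
      new RDDInfo(0, "lines", 8, StorageLevel.DISK_ONLY))

    // Ordered[RDDInfo] compares ids, so sorting yields ids 0, 2
    val sorted = infos.sorted
    assert(sorted.map(_.id) == Seq(0, 2))
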
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index bb07c8cb134cc..35910e552fe86 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -169,23 +169,47 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
throw new IllegalStateException("Failed to find shuffle block: " + id)
}
+ /** Remove all the blocks / files and metadata related to a particular shuffle. */
+ def removeShuffle(shuffleId: ShuffleId): Boolean = {
+ // Do not change the ordering here: shuffleStates should be removed only
+ // after the corresponding shuffle blocks have been removed
+ val cleaned = removeShuffleBlocks(shuffleId)
+ shuffleStates.remove(shuffleId)
+ cleaned
+ }
+
+ /** Remove all the blocks / files related to a particular shuffle. */
+ private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
+ shuffleStates.get(shuffleId) match {
+ case Some(state) =>
+ if (consolidateShuffleFiles) {
+ for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
+ file.delete()
+ }
+ } else {
+ for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
+ val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
+ blockManager.diskBlockManager.getFile(blockId).delete()
+ }
+ }
+ logInfo("Deleted all files for shuffle " + shuffleId)
+ true
+ case None =>
+ logInfo("Could not find files for shuffle " + shuffleId + " for deleting")
+ false
+ }
+ }
+
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
private def cleanup(cleanupTime: Long) {
- shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
- if (consolidateShuffleFiles) {
- for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
- file.delete()
- }
- } else {
- for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
- val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
- blockManager.diskBlockManager.getFile(blockId).delete()
- }
- }
- })
+ shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
+ }
+
+ def stop() {
+ metadataCleaner.cancel()
}
}
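
A sketch of how a cleanup path might drive the new removeShuffle API (the shuffleBlockManager
value and the shuffle id are assumed to be in scope):

    // Remove all blocks, files, and metadata for shuffle 3. The Boolean reports
    // whether any state was actually found for that shuffle id.
    val removed = shuffleBlockManager.removeShuffle(3)
    if (!removed) {
      // Either the shuffle never ran on this node or the periodic cleanup got there first.
      println("No shuffle state found for shuffle 3")
    }
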
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 1b7934d59fa1d..c9a52e0366d93 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -19,10 +19,13 @@ package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import org.apache.spark.annotation.DeveloperApi
+
/**
* Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
- * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
- * in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
+ * or Tachyon, whether to drop the RDD to disk if it falls out of memory or Tachyon, whether to
+ * keep the data in memory in a serialized format, and whether to replicate the RDD partitions on
+ * multiple nodes.
* The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants
* for commonly useful storage levels. To create your own storage level object, use the
* factory method of the singleton object (`StorageLevel(...)`).
@@ -30,45 +33,58 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
class StorageLevel private(
private var useDisk_ : Boolean,
private var useMemory_ : Boolean,
+ private var useOffHeap_ : Boolean,
private var deserialized_ : Boolean,
private var replication_ : Int = 1)
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
private def this(flags: Int, replication: Int) {
- this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
+ this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
}
- def this() = this(false, true, false) // For deserialization
+ def this() = this(false, true, false, false) // For deserialization
def useDisk = useDisk_
def useMemory = useMemory_
+ def useOffHeap = useOffHeap_
def deserialized = deserialized_
def replication = replication_
assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+ if (useOffHeap) {
+ require(!useDisk, "Off-heap storage level does not support using disk")
+ require(!useMemory, "Off-heap storage level does not support using heap memory")
+ require(!deserialized, "Off-heap storage level does not support deserialized storage")
+ require(replication == 1, "Off-heap storage level does not support multiple replication")
+ }
+
override def clone(): StorageLevel = new StorageLevel(
- this.useDisk, this.useMemory, this.deserialized, this.replication)
+ this.useDisk, this.useMemory, this.useOffHeap, this.deserialized, this.replication)
override def equals(other: Any): Boolean = other match {
case s: StorageLevel =>
s.useDisk == useDisk &&
s.useMemory == useMemory &&
+ s.useOffHeap == useOffHeap &&
s.deserialized == deserialized &&
s.replication == replication
case _ =>
false
}
- def isValid = ((useMemory || useDisk) && (replication > 0))
+ def isValid = ((useMemory || useDisk || useOffHeap) && (replication > 0))
def toInt: Int = {
var ret = 0
if (useDisk_) {
- ret |= 4
+ ret |= 8
}
if (useMemory_) {
+ ret |= 4
+ }
+ if (useOffHeap_) {
ret |= 2
}
if (deserialized_) {
@@ -84,8 +100,9 @@ class StorageLevel private(
override def readExternal(in: ObjectInput) {
val flags = in.readByte()
- useDisk_ = (flags & 4) != 0
- useMemory_ = (flags & 2) != 0
+ useDisk_ = (flags & 8) != 0
+ useMemory_ = (flags & 4) != 0
+ useOffHeap_ = (flags & 2) != 0
deserialized_ = (flags & 1) != 0
replication_ = in.readByte()
}
@@ -93,14 +110,15 @@ class StorageLevel private(
@throws(classOf[IOException])
private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this)
- override def toString: String =
- "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+ override def toString: String = "StorageLevel(%b, %b, %b, %b, %d)".format(
+ useDisk, useMemory, useOffHeap, deserialized, replication)
override def hashCode(): Int = toInt * 41 + replication
def description : String = {
var result = ""
result += (if (useDisk) "Disk " else "")
result += (if (useMemory) "Memory " else "")
+ result += (if (useOffHeap) "Tachyon " else "")
result += (if (deserialized) "Deserialized " else "Serialized ")
result += "%sx Replicated".format(replication)
result
@@ -113,28 +131,51 @@ class StorageLevel private(
* new storage levels.
*/
object StorageLevel {
- val NONE = new StorageLevel(false, false, false)
- val DISK_ONLY = new StorageLevel(true, false, false)
- val DISK_ONLY_2 = new StorageLevel(true, false, false, 2)
- val MEMORY_ONLY = new StorageLevel(false, true, true)
- val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2)
- val MEMORY_ONLY_SER = new StorageLevel(false, true, false)
- val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2)
- val MEMORY_AND_DISK = new StorageLevel(true, true, true)
- val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
- val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
- val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
-
- /** Create a new StorageLevel object */
- def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
- getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
-
- /** Create a new StorageLevel object from its integer representation */
- def apply(flags: Int, replication: Int) =
+ val NONE = new StorageLevel(false, false, false, false)
+ val DISK_ONLY = new StorageLevel(true, false, false, false)
+ val DISK_ONLY_2 = new StorageLevel(true, false, false, false, 2)
+ val MEMORY_ONLY = new StorageLevel(false, true, false, true)
+ val MEMORY_ONLY_2 = new StorageLevel(false, true, false, true, 2)
+ val MEMORY_ONLY_SER = new StorageLevel(false, true, false, false)
+ val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, false, 2)
+ val MEMORY_AND_DISK = new StorageLevel(true, true, false, true)
+ val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2)
+ val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false)
+ val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2)
+ val OFF_HEAP = new StorageLevel(false, false, true, false)
+
+ /**
+ * :: DeveloperApi ::
+ * Create a new StorageLevel object
+ */
+ @DeveloperApi
+ def apply(useDisk: Boolean, useMemory: Boolean, useOffHeap: Boolean,
+ deserialized: Boolean, replication: Int) = getCachedStorageLevel(
+ new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication))
+
+ /**
+ * :: DeveloperApi ::
+ * Create a new StorageLevel object without setting useOffHeap
+ */
+ @DeveloperApi
+ def apply(useDisk: Boolean, useMemory: Boolean,
+ deserialized: Boolean, replication: Int = 1) = getCachedStorageLevel(
+ new StorageLevel(useDisk, useMemory, false, deserialized, replication))
+
+ /**
+ * :: DeveloperApi ::
+ * Create a new StorageLevel object from its integer representation
+ */
+ @DeveloperApi
+ def apply(flags: Int, replication: Int): StorageLevel =
getCachedStorageLevel(new StorageLevel(flags, replication))
- /** Read StorageLevel object from ObjectInput stream */
- def apply(in: ObjectInput) = {
+ /**
+ * :: DeveloperApi ::
+ * Read StorageLevel object from ObjectInput stream
+ */
+ @DeveloperApi
+ def apply(in: ObjectInput): StorageLevel = {
val obj = new StorageLevel()
obj.readExternal(in)
getCachedStorageLevel(obj)
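
After this change the flag layout is disk = 8, memory = 4, off-heap = 2, deserialized = 1, and the
integer form of a level is their bitwise OR. A worked example against the public
apply(flags, replication):

    // MEMORY_AND_DISK = disk | memory | deserialized = 8 + 4 + 1 = 13
    val level = StorageLevel(13, 1)
    assert(level.useDisk && level.useMemory && level.deserialized && !level.useOffHeap)

    // OFF_HEAP = 2: only the off-heap bit set, with replication fixed at 1
    assert(StorageLevel(2, 1) == StorageLevel.OFF_HEAP)
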
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
new file mode 100644
index 0000000000000..7a174959037be
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+
+import org.apache.spark.scheduler._
+
+/**
+ * A SparkListener that maintains executor storage status
+ */
+private[spark] class StorageStatusListener extends SparkListener {
+ private val executorIdToStorageStatus = mutable.Map[String, StorageStatus]()
+
+ def storageStatusList = executorIdToStorageStatus.values.toSeq
+
+ /** Update storage status list to reflect updated block statuses */
+ def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) {
+ val filteredStatus = storageStatusList.find(_.blockManagerId.executorId == execId)
+ filteredStatus.foreach { storageStatus =>
+ updatedBlocks.foreach { case (blockId, updatedStatus) =>
+ storageStatus.blocks(blockId) = updatedStatus
+ }
+ }
+ }
+
+ /** Update storage status list to reflect the removal of an RDD from the cache */
+ def updateStorageStatus(unpersistedRDDId: Int) {
+ storageStatusList.foreach { storageStatus =>
+ val unpersistedBlockIds = storageStatus.rddBlocks.keys.filter(_.rddId == unpersistedRDDId)
+ unpersistedBlockIds.foreach { blockId =>
+ storageStatus.blocks(blockId) = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)
+ }
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
+ val info = taskEnd.taskInfo
+ val metrics = taskEnd.taskMetrics
+ if (info != null && metrics != null) {
+ val execId = formatExecutorId(info.executorId)
+ val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
+ if (updatedBlocks.length > 0) {
+ updateStorageStatus(execId, updatedBlocks)
+ }
+ }
+ }
+
+ override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized {
+ updateStorageStatus(unpersistRDD.rddId)
+ }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) {
+ synchronized {
+ val blockManagerId = blockManagerAdded.blockManagerId
+ val executorId = blockManagerId.executorId
+ val maxMem = blockManagerAdded.maxMem
+ val storageStatus = new StorageStatus(blockManagerId, maxMem)
+ executorIdToStorageStatus(executorId) = storageStatus
+ }
+ }
+
+ override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) {
+ synchronized {
+ val executorId = blockManagerRemoved.blockManagerId.executorId
+ executorIdToStorageStatus.remove(executorId)
+ }
+ }
+
+ /**
+ * In the local mode, there is a discrepancy between the executor ID according to the
+ * task ("localhost") and that according to SparkEnv (""). In the UI, this
+ * results in duplicate rows for the same executor. Thus, in this mode, we aggregate
+ * these two rows and use the executor ID of "" to be consistent.
+ */
+ def formatExecutorId(execId: String): String = {
+ if (execId == "localhost") "" else execId
+ }
+}
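
A sketch of the listener's lifecycle, assuming bmId is an existing BlockManagerId whose
executorId is "0", and that SparkListenerBlockManagerAdded carries the (blockManagerId, maxMem)
pair read in onBlockManagerAdded above:

    // bmId: assumed pre-existing BlockManagerId for executor "0"
    val listener = new StorageStatusListener
    listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bmId, 512L * 1024 * 1024))

    // A task finished and reported one freshly cached RDD block
    listener.updateStorageStatus("0",
      Seq((RDDBlockId(0, 0), BlockStatus(StorageLevel.MEMORY_ONLY, 100L, 0L, 0L))))

    assert(listener.storageStatusList.head.memUsed == 100L)
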
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 2d88a40fbb3f2..1eddd1cdc483b 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -17,100 +17,108 @@
package org.apache.spark.storage
+import scala.collection.Map
+import scala.collection.mutable
+
import org.apache.spark.SparkContext
-import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus
-import org.apache.spark.util.Utils
-private[spark]
-case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
- blocks: Map[BlockId, BlockStatus]) {
+/** Storage information for each BlockManager. */
+private[spark] class StorageStatus(
+ val blockManagerId: BlockManagerId,
+ val maxMem: Long,
+ val blocks: mutable.Map[BlockId, BlockStatus] = mutable.Map.empty) {
- def memUsed() = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
+ def memUsed = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
def memUsedByRDD(rddId: Int) =
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
- def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
+ def diskUsed = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
def diskUsedByRDD(rddId: Int) =
rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
- def memRemaining : Long = maxMem - memUsed()
-
- def rddBlocks = blocks.flatMap {
- case (rdd: RDDBlockId, status) => Some(rdd, status)
- case _ => None
- }
-}
-
-case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
- numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long)
- extends Ordered[RDDInfo] {
- override def toString = {
- import Utils.bytesToString
- ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " +
- "DiskSize: %s").format(name, id, storageLevel.toString, numCachedPartitions,
- numPartitions, bytesToString(memSize), bytesToString(diskSize))
- }
+ def memRemaining: Long = maxMem - memUsed
- override def compare(that: RDDInfo) = {
- this.id - that.id
- }
+ def rddBlocks = blocks.collect { case (rdd: RDDBlockId, status) => (rdd, status) }
}
-/* Helper methods for storage-related objects */
-private[spark]
-object StorageUtils {
-
- /* Returns RDD-level information, compiled from a list of StorageStatus objects */
- def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus],
- sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(
- storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc)
+/** Helper methods for storage-related objects. */
+private[spark] object StorageUtils {
+
+ /**
+ * Returns basic information of all RDDs persisted in the given SparkContext. This does not
+ * include storage information.
+ */
+ def rddInfoFromSparkContext(sc: SparkContext): Array[RDDInfo] = {
+ sc.persistentRdds.values.map { rdd =>
+ val rddName = Option(rdd.name).getOrElse(rdd.id.toString)
+ val rddNumPartitions = rdd.partitions.size
+ val rddStorageLevel = rdd.getStorageLevel
+ val rddInfo = new RDDInfo(rdd.id, rddName, rddNumPartitions, rddStorageLevel)
+ rddInfo
+ }.toArray
}
- /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
- def blockLocationsFromStorageStatus(storageStatusList: Seq[StorageStatus]) = {
- val blockLocationPairs = storageStatusList
- .flatMap(s => s.blocks.map(b => (b._1, s.blockManagerId.hostPort)))
- blockLocationPairs.groupBy(_._1).map{case (k, v) => (k, v.unzip._2)}.toMap
+ /** Returns storage information of all RDDs persisted in the given SparkContext. */
+ def rddInfoFromStorageStatus(
+ storageStatuses: Seq[StorageStatus],
+ sc: SparkContext): Array[RDDInfo] = {
+ rddInfoFromStorageStatus(storageStatuses, rddInfoFromSparkContext(sc))
}
- /* Given a list of BlockStatus objets, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus],
- sc: SparkContext) : Array[RDDInfo] = {
-
- // Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray)
-
- // For each RDD, generate an RDDInfo object
- val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) =>
- // Add up memory and disk sizes
- val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
- val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
-
- // Get the friendly name and storage level for the RDD, if available
- sc.persistentRdds.get(rddId).map { r =>
- val rddName = Option(r.name).getOrElse(rddId.toString)
- val rddStorageLevel = r.getStorageLevel
- RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size,
- memSize, diskSize)
+ /** Returns storage information of all RDDs in the given list. */
+ def rddInfoFromStorageStatus(
+ storageStatuses: Seq[StorageStatus],
+ rddInfos: Seq[RDDInfo]): Array[RDDInfo] = {
+
+ // Mapping from RDD ID -> an array of associated BlockStatuses
+ val blockStatusMap = storageStatuses.flatMap(_.rddBlocks).toMap
+ .groupBy { case (k, _) => k.rddId }
+ .mapValues(_.values.toArray)
+
+ // Mapping from RDD ID -> the associated RDDInfo (with potentially outdated storage information)
+ val rddInfoMap = rddInfos.map { info => (info.id, info) }.toMap
+
+ val rddStorageInfos = blockStatusMap.flatMap { case (rddId, blocks) =>
+ // Add up memory, disk and Tachyon sizes
+ val persistedBlocks =
+ blocks.filter { status => status.memSize + status.diskSize + status.tachyonSize > 0 }
+ val memSize = persistedBlocks.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
+ val diskSize = persistedBlocks.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
+ val tachyonSize = persistedBlocks.map(_.tachyonSize).reduceOption(_ + _).getOrElse(0L)
+ rddInfoMap.get(rddId).map { rddInfo =>
+ rddInfo.numCachedPartitions = persistedBlocks.length
+ rddInfo.memSize = memSize
+ rddInfo.diskSize = diskSize
+ rddInfo.tachyonSize = tachyonSize
+ rddInfo
}
- }.flatten.toArray
-
- scala.util.Sorting.quickSort(rddInfos)
+ }.toArray
- rddInfos
+ scala.util.Sorting.quickSort(rddStorageInfos)
+ rddStorageInfos
}
- /* Filters storage status by a given RDD id. */
- def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int)
- : Array[StorageStatus] = {
-
- storageStatusList.map { status =>
- val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus]
- //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
- StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
+ /** Returns a mapping from BlockId to the locations of the associated block. */
+ def blockLocationsFromStorageStatus(
+ storageStatuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = {
+ val blockLocationPairs = storageStatuses.flatMap { storageStatus =>
+ storageStatus.blocks.map { case (bid, _) => (bid, storageStatus.blockManagerId.hostPort) }
}
+ blockLocationPairs.toMap
+ .groupBy { case (blockId, _) => blockId }
+ .mapValues(_.values.toSeq)
+ }
+
+ /** Filters the given list of StorageStatus by the given RDD ID. */
+ def filterStorageStatusByRDD(
+ storageStatuses: Seq[StorageStatus],
+ rddId: Int): Array[StorageStatus] = {
+ storageStatuses.map { status =>
+ val filteredBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toSeq
+ val filteredBlockMap = mutable.Map[BlockId, BlockStatus](filteredBlocks: _*)
+ new StorageStatus(status.blockManagerId, status.maxMem, filteredBlockMap)
+ }.toArray
}
}
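
One subtle change above is the switch from reduce to reduceOption(_ + _).getOrElse(0L) when
summing block sizes: reduce throws on an empty collection (e.g. an RDD whose blocks were all
dropped), while the option-based form degrades to zero. A standalone illustration:

    val sizes = Seq.empty[Long]

    // sizes.reduce(_ + _) would throw UnsupportedOperationException here,
    // whereas reduceOption yields None, which we default to 0L
    val total = sizes.reduceOption(_ + _).getOrElse(0L)
    assert(total == 0L)
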
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
new file mode 100644
index 0000000000000..b0b9674856568
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+
+import tachyon.client.TachyonFS
+import tachyon.client.TachyonFile
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.ShuffleSender
+import org.apache.spark.util.Utils
+
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and Tachyon fs locations. By
+ * default, one block is mapped to one file with a name given by its BlockId.
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class TachyonBlockManager(
+ shuffleManager: ShuffleBlockManager,
+ rootDirs: String,
+ val master: String)
+ extends Logging {
+
+ val client = if (master != null && master != "") TachyonFS.get(master) else null
+
+ if (client == null) {
+ logError("Failed to connect to the Tachyon as the master address is not configured")
+ System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_INITIALIZE)
+ }
+
+ private val MAX_DIR_CREATION_ATTEMPTS = 10
+ private val subDirsPerTachyonDir =
+ shuffleManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt
+
+ // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName;
+ // then, inside this directory, create multiple subdirectories that we will hash files into,
+ // in order to avoid having really large inodes at the top level in Tachyon.
+ private val tachyonDirs: Array[TachyonFile] = createTachyonDirs()
+ private val subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir))
+
+ addShutdownHook()
+
+ def removeFile(file: TachyonFile): Boolean = {
+ client.delete(file.getPath(), false)
+ }
+
+ def fileExists(file: TachyonFile): Boolean = {
+ client.exist(file.getPath())
+ }
+
+ def getFile(filename: String): TachyonFile = {
+ // Figure out which tachyon directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % tachyonDirs.length
+ val subDirId = (hash / tachyonDirs.length) % subDirsPerTachyonDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId)
+ client.mkdir(path)
+ val newDir = client.getFile(path)
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+ val filePath = subDir + "/" + filename
+ if (!client.exist(filePath)) {
+ client.createFile(filePath)
+ }
+ val file = client.getFile(filePath)
+ file
+ }
+
+ def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name)
+
+ // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore.
+ private def createTachyonDirs(): Array[TachyonFile] = {
+ logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var tachyonDir: TachyonFile = null
+ var tachyonDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId
+ if (!client.exist(path)) {
+ foundLocalDir = client.mkdir(path)
+ tachyonDir = client.getFile(path)
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create tachyon dir in " +
+ rootDir)
+ System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created tachyon directory at " + tachyonDir)
+ tachyonDir
+ }
+ }
+
+ private def addShutdownHook() {
+ tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark tachyon dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ tachyonDirs.foreach { tachyonDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) {
+ Utils.deleteRecursively(tachyonDir, client)
+ }
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting tachyon spark dir: " + tachyonDir, t)
+ }
+ }
+ }
+ })
+ }
+}
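
The two-level directory scheme places a file by hashing its name: the hash modulo the number of
root dirs picks the directory, and the quotient modulo subDirsPerTachyonDir picks the "%02x"
subdirectory. A standalone sketch of the arithmetic (nonNegativeHash here is a stand-in for
Utils.nonNegativeHash, whose exact implementation may differ):

    def nonNegativeHash(s: String): Int = s.hashCode & Integer.MAX_VALUE  // stand-in

    val numRootDirs = 2    // e.g. two paths in spark.tachyonStore.folderName
    val subDirsPerDir = 64 // default for spark.tachyonStore.subDirectories

    val hash = nonNegativeHash("rdd_7_3")
    val dirId = hash % numRootDirs
    val subDirId = (hash / numRootDirs) % subDirsPerDir
    println(f"root dir $dirId%d, subdirectory $subDirId%02x")
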
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala
new file mode 100644
index 0000000000000..b86abbda1d3e7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import tachyon.client.TachyonFile
+
+/**
+ * References a particular segment of a file (potentially the entire file), based off an offset and
+ * a length.
+ */
+private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length)
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
new file mode 100644
index 0000000000000..c37e76f893605
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.IOException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import tachyon.client.{WriteType, ReadType}
+
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+import org.apache.spark.serializer.Serializer
+
+
+private class Entry(val size: Long)
+
+
+/**
+ * Stores BlockManager blocks on Tachyon.
+ */
+private class TachyonStore(
+ blockManager: BlockManager,
+ tachyonManager: TachyonBlockManager)
+ extends BlockStore(blockManager: BlockManager) with Logging {
+
+ logInfo("TachyonStore started")
+
+ override def getSize(blockId: BlockId): Long = {
+ tachyonManager.getFile(blockId.name).length
+ }
+
+ override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = {
+ putToTachyonStore(blockId, bytes, true)
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: ArrayBuffer[Any],
+ level: StorageLevel,
+ returnValues: Boolean): PutResult = {
+ return putValues(blockId, values.toIterator, level, returnValues)
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean): PutResult = {
+ logDebug("Attempting to write values for block " + blockId)
+ val _bytes = blockManager.dataSerialize(blockId, values)
+ putToTachyonStore(blockId, _bytes, returnValues)
+ }
+
+ private def putToTachyonStore(
+ blockId: BlockId,
+ bytes: ByteBuffer,
+ returnValues: Boolean): PutResult = {
+ // So that we do not modify the input offsets !
+ // duplicate does not copy buffer, so inexpensive
+ val byteBuffer = bytes.duplicate()
+ byteBuffer.rewind()
+ logDebug("Attempting to put block " + blockId + " into Tachyon")
+ val startTime = System.currentTimeMillis
+ val file = tachyonManager.getFile(blockId)
+ val os = file.getOutStream(WriteType.TRY_CACHE)
+ os.write(byteBuffer.array())
+ os.close()
+ val finishTime = System.currentTimeMillis
+ logDebug("Block %s stored as %s file in Tachyon in %d ms".format(
+ blockId, Utils.bytesToString(byteBuffer.limit), (finishTime - startTime)))
+
+ if (returnValues) {
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
+ } else {
+ PutResult(bytes.limit(), null)
+ }
+ }
+
+ override def remove(blockId: BlockId): Boolean = {
+ val file = tachyonManager.getFile(blockId)
+ if (tachyonManager.fileExists(file)) {
+ tachyonManager.removeFile(file)
+ } else {
+ false
+ }
+ }
+
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
+ }
+
+
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ val file = tachyonManager.getFile(blockId)
+ if (file == null || file.getLocationHosts().size == 0) {
+ return None
+ }
+ val is = file.getInStream(ReadType.CACHE)
+ var buffer: ByteBuffer = null
+ try {
+ if (is != null) {
+ val size = file.length
+ val bs = new Array[Byte](size.asInstanceOf[Int])
+ val fetchSize = is.read(bs, 0, size.asInstanceOf[Int])
+ buffer = ByteBuffer.wrap(bs)
+ if (fetchSize != size) {
+ logWarning("Failed to fetch the block " + blockId + " from Tachyon : Size " + size +
+ " is not equal to fetched size " + fetchSize)
+ return None
+ }
+ }
+ } catch {
+ case ioe: IOException => {
+ logWarning("Failed to fetch the block " + blockId + " from Tachyon", ioe)
+ return None
+ }
+ }
+ Some(buffer)
+ }
+
+ override def contains(blockId: BlockId): Boolean = {
+ val file = tachyonManager.getFile(blockId)
+ tachyonManager.fileExists(file)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 1d81d006c0b29..a107c5182b3be 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -22,7 +22,8 @@ import java.util.concurrent.ArrayBlockingQueue
import akka.actor._
import util.Random
-import org.apache.spark.SparkConf
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
+import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.KryoSerializer
/**
@@ -47,7 +48,7 @@ private[spark] object ThreadingTest {
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
val startTime = System.currentTimeMillis()
- manager.put(blockId, block.iterator, level, true)
+ manager.put(blockId, block.iterator, level, tellMaster = true)
println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
queue.add((blockId, block))
}
@@ -96,9 +97,11 @@ private[spark] object ThreadingTest {
val conf = new SparkConf()
val serializer = new KryoSerializer(conf)
val blockManagerMaster = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf)
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
+ conf)
val blockManager = new BlockManager(
- "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf)
+ "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
+ new SecurityManager(conf), new MapOutputTrackerMaster(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 1f048a84cdfb6..b3ac2320f3431 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -17,98 +17,168 @@
package org.apache.spark.ui
-import java.net.InetSocketAddress
-import javax.servlet.http.{HttpServletResponse, HttpServletRequest}
+import java.net.{InetSocketAddress, URL}
+import javax.servlet.DispatcherType
+import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import scala.annotation.tailrec
+import scala.language.implicitConversions
import scala.util.{Failure, Success, Try}
import scala.xml.Node
-import net.liftweb.json.{JValue, pretty, render}
-import org.eclipse.jetty.server.{Handler, Request, Server}
-import org.eclipse.jetty.server.handler.{AbstractHandler, ContextHandler, HandlerList, ResourceHandler}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.handler._
+import org.eclipse.jetty.servlet._
import org.eclipse.jetty.util.thread.QueuedThreadPool
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods.{pretty, render}
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.util.Utils
-/** Utilities for launching a web server using Jetty's HTTP Server class */
+/**
+ * Utilities for launching a web server using Jetty's HTTP Server class
+ */
private[spark] object JettyUtils extends Logging {
+
// Base type for a function that returns something based on an HTTP request. Allows for
// implicit conversion from many types of functions to jetty Handlers.
-
type Responder[T] = HttpServletRequest => T
- // Conversions from various types of Responder's to jetty Handlers
- implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler =
- createHandler(responder, "text/json", (in: JValue) => pretty(render(in)))
-
- implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler =
- createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString)
-
- implicit def textResponderToHandler(responder: Responder[String]): Handler =
- createHandler(responder, "text/plain")
-
- def createHandler[T <% AnyRef](responder: Responder[T], contentType: String,
- extractFn: T => String = (in: Any) => in.toString): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
- response: HttpServletResponse) {
- response.setContentType("%s;charset=utf-8".format(contentType))
- response.setStatus(HttpServletResponse.SC_OK)
- baseRequest.setHandled(true)
- val result = responder(request)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.getWriter().println(extractFn(result))
+ class ServletParams[T <% AnyRef](val responder: Responder[T],
+ val contentType: String,
+ val extractFn: T => String = (in: Any) => in.toString) {}
+
+ // Conversions from various types of Responder's to appropriate servlet parameters
+ implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] =
+ new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in)))
+
+ implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] =
+ new ServletParams(responder, "text/html", (in: Seq[Node]) => "" + in.toString)
+
+ implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] =
+ new ServletParams(responder, "text/plain")
+
+ def createServlet[T <% AnyRef](
+ servletParams: ServletParams[T],
+ securityMgr: SecurityManager): HttpServlet = {
+ new HttpServlet {
+ override def doGet(request: HttpServletRequest, response: HttpServletResponse) {
+ if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) {
+ response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
+ response.setStatus(HttpServletResponse.SC_OK)
+ val result = servletParams.responder(request)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.getWriter.println(servletParams.extractFn(result))
+ } else {
+ response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
+ "User is not authorized to access this page.")
+ }
}
}
}
- /** Creates a handler that always redirects the user to a given path */
- def createRedirectHandler(newPath: String): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
- response: HttpServletResponse) {
- response.setStatus(302)
- response.setHeader("Location", baseRequest.getRootURL + newPath)
- baseRequest.setHandled(true)
+ /** Create a context handler that responds to a request with the given path prefix */
+ def createServletHandler[T <% AnyRef](
+ path: String,
+ servletParams: ServletParams[T],
+ securityMgr: SecurityManager,
+ basePath: String = ""): ServletContextHandler = {
+ createServletHandler(path, createServlet(servletParams, securityMgr), basePath)
+ }
+
+ /** Create a context handler that responds to a request with the given path prefix */
+ def createServletHandler(
+ path: String,
+ servlet: HttpServlet,
+ basePath: String = ""): ServletContextHandler = {
+ val prefixedPath = attachPrefix(basePath, path)
+ val contextHandler = new ServletContextHandler
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(prefixedPath)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
+ }
+
+ /** Create a handler that always redirects the user to the given path */
+ def createRedirectHandler(
+ srcPath: String,
+ destPath: String,
+ beforeRedirect: HttpServletRequest => Unit = x => (),
+ basePath: String = ""): ServletContextHandler = {
+ val prefixedDestPath = attachPrefix(basePath, destPath)
+ val servlet = new HttpServlet {
+ override def doGet(request: HttpServletRequest, response: HttpServletResponse) {
+ beforeRedirect(request)
+ // Make sure we don't end up with "//" in the middle
+ val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString
+ response.sendRedirect(newUrl)
}
}
+ createServletHandler(srcPath, servlet, basePath)
}
- /** Creates a handler for serving files from a static directory */
- def createStaticHandler(resourceBase: String): ResourceHandler = {
- val staticHandler = new ResourceHandler
- Option(getClass.getClassLoader.getResource(resourceBase)) match {
+ /** Create a handler for serving files from a static directory */
+ def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = {
+ val contextHandler = new ServletContextHandler
+ contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false")
+ val staticHandler = new DefaultServlet
+ val holder = new ServletHolder(staticHandler)
+ Option(Utils.getSparkClassLoader.getResource(resourceBase)) match {
case Some(res) =>
- staticHandler.setResourceBase(res.toString)
+ holder.setInitParameter("resourceBase", res.toString)
case None =>
throw new Exception("Could not find resource path for Web UI: " + resourceBase)
}
- staticHandler
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
+ }
+
+ /** Add filters, if any, to the given list of ServletContextHandlers */
+ def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) {
+ val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim())
+ filters.foreach {
+ case filter: String =>
+ if (!filter.isEmpty) {
+ logInfo("Adding filter: " + filter)
+ val holder: FilterHolder = new FilterHolder()
+ holder.setClassName(filter)
+ // Get any parameters for each filter
+ val paramName = "spark." + filter + ".params"
+ val params = conf.get(paramName, "").split(',').map(_.trim()).toSet
+ params.foreach {
+ case param: String =>
+ if (!param.isEmpty) {
+ val parts = param.split("=")
+ if (parts.length == 2) holder.setInitParameter(parts(0), parts(1))
+ }
+ }
+ val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
+ DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
+ handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) }
+ }
+ }
}
/**
- * Attempts to start a Jetty server at the supplied hostName:port which uses the supplied
- * handlers.
+ * Attempt to start a Jetty server bound to the supplied hostName:port using the given
+ * context handlers.
*
- * If the desired port number is contented, continues incrementing ports until a free port is
- * found. Returns the chosen port and the jetty Server object.
+ * If the desired port number is contended, continues incrementing ports until a free port is
+ * found. Return the jetty Server object, the chosen port, and a mutable collection of handlers.
*/
- def startJettyServer(hostName: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int)
- = {
-
- val handlersToRegister = handlers.map { case(path, handler) =>
- val contextHandler = new ContextHandler(path)
- contextHandler.setHandler(handler)
- contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler]
- }
+ def startJettyServer(
+ hostName: String,
+ port: Int,
+ handlers: Seq[ServletContextHandler],
+ conf: SparkConf): ServerInfo = {
- val handlerList = new HandlerList
- handlerList.setHandlers(handlersToRegister.toArray)
+ val collection = new ContextHandlerCollection
+ collection.setHandlers(handlers.toArray)
+ addFilters(handlers, conf)
@tailrec
def connect(currentPort: Int): (Server, Int) = {
@@ -116,19 +186,33 @@ private[spark] object JettyUtils extends Logging {
val pool = new QueuedThreadPool
pool.setDaemon(true)
server.setThreadPool(pool)
- server.setHandler(handlerList)
+ server.setHandler(collection)
- Try { server.start() } match {
+ Try {
+ server.start()
+ } match {
case s: Success[_] =>
(server, server.getConnectors.head.getLocalPort)
case f: Failure[_] =>
server.stop()
+ pool.stop()
logInfo("Failed to create UI at port, %s. Trying again.".format(currentPort))
logInfo("Error was: " + f.toString)
connect((currentPort + 1) % 65536)
}
}
- connect(port)
+ val (server, boundPort) = connect(port)
+ ServerInfo(server, boundPort, collection)
+ }
+
+ /** Attach a prefix to the given path, but avoid returning an empty path */
+ private def attachPrefix(basePath: String, relativePath: String): String = {
+ if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/")
}
}
+
+private[spark] case class ServerInfo(
+ server: Server,
+ boundPort: Int,
+ rootHandler: ContextHandlerCollection)
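ServerInfo is what callers get back from startJettyServer: the live server, the port it actually bound to (which can differ from the requested port after retries), and the handler collection that attachHandler/detachHandler mutate later. A minimal sketch of consuming it, assuming `handlers` and `conf` are in scope as in WebUI.bind() below:

```scala
// Sketch: start the server and destructure the result.
val ServerInfo(server, boundPort, rootHandler) =
  JettyUtils.startJettyServer("0.0.0.0", 4040, handlers, conf)
// boundPort may be greater than 4040 if 4040 was contended at startup
```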
diff --git a/core/src/main/scala/org/apache/spark/ui/Page.scala b/core/src/main/scala/org/apache/spark/ui/Page.scala
deleted file mode 100644
index b2a069a37552d..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/Page.scala
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui
-
-private[spark] object Page extends Enumeration {
- val Stages, Storage, Environment, Executors = Value
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index af6b65860e006..097a1b81e1dd1 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -17,71 +17,94 @@
package org.apache.spark.ui
-import org.eclipse.jetty.server.{Handler, Server}
-
-import org.apache.spark.{Logging, SparkContext, SparkEnv}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.env.EnvironmentUI
-import org.apache.spark.ui.exec.ExecutorsUI
-import org.apache.spark.ui.jobs.JobProgressUI
-import org.apache.spark.ui.storage.BlockManagerUI
-import org.apache.spark.util.Utils
-
-/** Top level user interface for Spark */
-private[spark] class SparkUI(sc: SparkContext) extends Logging {
- val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName())
- val port = sc.conf.get("spark.ui.port", SparkUI.DEFAULT_PORT).toInt
- var boundPort: Option[Int] = None
- var server: Option[Server] = None
-
- val handlers = Seq[(String, Handler)](
- ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)),
- ("/", createRedirectHandler("/stages"))
- )
- val storage = new BlockManagerUI(sc)
- val jobs = new JobProgressUI(sc)
- val env = new EnvironmentUI(sc)
- val exec = new ExecutorsUI(sc)
-
- // Add MetricsServlet handlers by default
- val metricsServletHandlers = SparkEnv.get.metricsSystem.getServletHandlers
-
- val allHandlers = storage.getHandlers ++ jobs.getHandlers ++ env.getHandlers ++
- exec.getHandlers ++ metricsServletHandlers ++ handlers
-
- /** Bind the HTTP server which backs this web interface */
- def bind() {
- try {
- val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers)
- logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort))
- server = Some(srv)
- boundPort = Some(usedPort)
- } catch {
- case e: Exception =>
- logError("Failed to create Spark JettyUtils", e)
- System.exit(1)
+import org.apache.spark.ui.env.EnvironmentTab
+import org.apache.spark.ui.exec.ExecutorsTab
+import org.apache.spark.ui.jobs.JobProgressTab
+import org.apache.spark.ui.storage.StorageTab
+
+/**
+ * Top level user interface for a Spark application.
+ */
+private[spark] class SparkUI(
+ val sc: SparkContext,
+ val conf: SparkConf,
+ val securityManager: SecurityManager,
+ val listenerBus: SparkListenerBus,
+ var appName: String,
+ val basePath: String = "")
+ extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath)
+ with Logging {
+
+ def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName)
+ def this(conf: SparkConf, listenerBus: SparkListenerBus, appName: String, basePath: String) =
+ this(null, conf, new SecurityManager(conf), listenerBus, appName, basePath)
+
+ def this(
+ conf: SparkConf,
+ securityManager: SecurityManager,
+ listenerBus: SparkListenerBus,
+ appName: String,
+ basePath: String) =
+ this(null, conf, securityManager, listenerBus, appName, basePath)
+
+ // If SparkContext is not provided, assume the associated application is not live
+ val live = sc != null
+
+ // Maintain executor storage status through Spark events
+ val storageStatusListener = new StorageStatusListener
+
+ initialize()
+
+ /** Initialize all components of the server. */
+ def initialize() {
+ listenerBus.addListener(storageStatusListener)
+ val jobProgressTab = new JobProgressTab(this)
+ attachTab(jobProgressTab)
+ attachTab(new StorageTab(this))
+ attachTab(new EnvironmentTab(this))
+ attachTab(new ExecutorsTab(this))
+ attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
+ attachHandler(createRedirectHandler("/", "/stages", basePath = basePath))
+ attachHandler(
+ createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest))
+ if (live) {
+ sc.env.metricsSystem.getServletHandlers.foreach(attachHandler)
}
}
- /** Initialize all components of the server */
- def start() {
- // NOTE: This is decoupled from bind() because of the following dependency cycle:
- // DAGScheduler() requires that the port of this server is known
- // This server must register all handlers, including JobProgressUI, before binding
- // JobProgressUI registers a listener with SparkContext, which requires sc to initialize
- jobs.start()
- exec.start()
+ /** Set the app name for this UI. */
+ def setAppName(name: String) {
+ appName = name
+ }
+
+ /** Register the given listener with the listener bus. */
+ def registerListener(listener: SparkListener) {
+ listenerBus.addListener(listener)
}
- def stop() {
- server.foreach(_.stop())
+ /** Stop the server behind this web interface. Only valid after bind(). */
+ override def stop() {
+ super.stop()
+ logInfo("Stopped Spark web UI at %s".format(appUIAddress))
}
- private[spark] def appUIAddress = host + ":" + boundPort.getOrElse("-1")
+ /**
+ * Return the application UI host:port. This does not include the scheme (http://).
+ */
+ private[spark] def appUIHostPort = publicHostName + ":" + boundPort
+ private[spark] def appUIAddress = s"http://$appUIHostPort"
}
private[spark] object SparkUI {
- val DEFAULT_PORT = "4040"
+ val DEFAULT_PORT = 4040
val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+
+ def getUIPort(conf: SparkConf): Int = {
+ conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT)
+ }
}
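getUIPort gives every caller one place to resolve the UI port. A sketch, assuming spark.ui.port is not also set through system properties:

```scala
// Sketch: overriding the UI port through SparkConf.
val conf = new SparkConf().set("spark.ui.port", "4041")
SparkUI.getUIPort(conf)             // 4041
SparkUI.getUIPort(new SparkConf())  // 4040, i.e. SparkUI.DEFAULT_PORT
```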
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 547a194d58a5c..a3d6a1821245b 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -17,40 +17,136 @@
package org.apache.spark.ui
-import scala.xml.Node
+import java.text.SimpleDateFormat
+import java.util.{Locale, Date}
-import org.apache.spark.SparkContext
+import scala.xml.Node
+import org.apache.spark.Logging
/** Utility functions for generating XML pages with spark content. */
-private[spark] object UIUtils {
- import Page._
+private[spark] object UIUtils extends Logging {
- // Yarn has to go through a proxy so the base uri is provided and has to be on all links
- private[spark] val uiRoot : String = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).
- getOrElse("")
+ // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
+ private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
+ override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ }
- def prependBaseUri(resource: String = "") = uiRoot + resource
+ def formatDate(date: Date): String = dateFormat.get.format(date)
- /** Returns a spark page with correctly formatted headers */
- def headerSparkPage(content: => Seq[Node], sc: SparkContext, title: String, page: Page.Value)
- : Seq[Node] = {
- val jobs = page match {
- case Stages =>
@@ -123,21 +214,36 @@ private[spark] object UIUtils {
/** Returns an HTML table constructed by generating a row for each object in a sequence. */
def listingTable[T](
headers: Seq[String],
- makeRow: T => Seq[Node],
- rows: Seq[T],
+ generateDataRow: T => Seq[Node],
+ data: Seq[T],
fixedWidth: Boolean = false): Seq[Node] = {
- val colWidth = 100.toDouble / headers.size
- val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
var tableClass = "table table-bordered table-striped table-condensed sortable"
if (fixedWidth) {
tableClass += " table-fixed"
}
-
+ val colWidth = 100.toDouble / headers.size
+ val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
+ val headerRow: Seq[Node] = {
+   // if none of the headers have "\n" in them
+   if (headers.forall(!_.contains("\n"))) {
+     // represent header as simple text
+     headers.map(h => <th width={colWidthAttr}>{h}</th>)
+   } else {
+     // represent header text as list while respecting "\n"
+     headers.map { case h =>
+       <th width={colWidthAttr}>
+         <ul class="unstyled">
+           {h.split("\n").map { case t => <li> {t} </li> }}
+         </ul>
+       </th>
+     }
+   }
+ }
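For reference, listingTable is generic in the row type: callers supply the headers, a row renderer, and the data. A minimal sketch with illustrative data:

```scala
// Sketch: render a two-column HTML table from key/value pairs.
val props = Seq(("spark.app.name", "demo"), ("spark.master", "local[*]"))
def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
val table = UIUtils.listingTable(Seq("Name", "Value"), propertyRow, props, fixedWidth = true)
```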
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
new file mode 100644
index 0000000000000..b08f308fda1dd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.collection.mutable.ArrayBuffer
+import scala.xml.Node
+
+import org.eclipse.jetty.servlet.ServletContextHandler
+import org.json4s.JsonAST.{JNothing, JValue}
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.ui.JettyUtils._
+import org.apache.spark.util.Utils
+
+/**
+ * The top level component of the UI hierarchy that contains the server.
+ *
+ * Each WebUI represents a collection of tabs, each of which in turn represents a collection of
+ * pages. The use of tabs is optional, however; a WebUI may choose to include pages directly.
+ */
+private[spark] abstract class WebUI(
+ securityManager: SecurityManager,
+ port: Int,
+ conf: SparkConf,
+ basePath: String = "")
+ extends Logging {
+
+ protected val tabs = ArrayBuffer[WebUITab]()
+ protected val handlers = ArrayBuffer[ServletContextHandler]()
+ protected var serverInfo: Option[ServerInfo] = None
+ protected val localHostName = Utils.localHostName()
+ protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName)
+ private val className = Utils.getFormattedClassName(this)
+
+ def getTabs: Seq[WebUITab] = tabs.toSeq
+ def getHandlers: Seq[ServletContextHandler] = handlers.toSeq
+
+ /** Attach a tab to this UI, along with all of its attached pages. */
+ def attachTab(tab: WebUITab) {
+ tab.pages.foreach(attachPage)
+ tabs += tab
+ }
+
+ /** Attach a page to this UI. */
+ def attachPage(page: WebUIPage) {
+ val pagePath = "/" + page.prefix
+ attachHandler(createServletHandler(pagePath,
+ (request: HttpServletRequest) => page.render(request), securityManager, basePath))
+ attachHandler(createServletHandler(pagePath.stripSuffix("/") + "/json",
+ (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath))
+ }
+
+ /** Attach a handler to this UI. */
+ def attachHandler(handler: ServletContextHandler) {
+ handlers += handler
+ serverInfo.foreach { info =>
+ info.rootHandler.addHandler(handler)
+ if (!handler.isStarted) {
+ handler.start()
+ }
+ }
+ }
+
+ /** Detach a handler from this UI. */
+ def detachHandler(handler: ServletContextHandler) {
+ handlers -= handler
+ serverInfo.foreach { info =>
+ info.rootHandler.removeHandler(handler)
+ if (handler.isStarted) {
+ handler.stop()
+ }
+ }
+ }
+
+ /** Initialize all components of the server. */
+ def initialize()
+
+ /** Bind to the HTTP server behind this web interface. */
+ def bind() {
+ assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className))
+ try {
+ serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf))
+ logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort))
+ } catch {
+ case e: Exception =>
+ logError("Failed to bind %s".format(className), e)
+ System.exit(1)
+ }
+ }
+
+ /** Return the actual port to which this server is bound. Only valid after bind(). */
+ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1)
+
+ /** Stop the server behind this web interface. Only valid after bind(). */
+ def stop() {
+ assert(serverInfo.isDefined,
+ "Attempted to stop %s before binding to a server!".format(className))
+ serverInfo.get.server.stop()
+ }
+}
+
+
+/**
+ * A tab that represents a collection of pages.
+ * The prefix is appended to the parent address to form a full path, and must not contain slashes.
+ */
+private[spark] abstract class WebUITab(parent: WebUI, val prefix: String) {
+ val pages = ArrayBuffer[WebUIPage]()
+ val name = prefix.capitalize
+
+ /** Attach a page to this tab. This prepends the page's prefix with the tab's own prefix. */
+ def attachPage(page: WebUIPage) {
+ page.prefix = (prefix + "/" + page.prefix).stripSuffix("/")
+ pages += page
+ }
+
+ /** Get a list of header tabs from the parent UI. */
+ def headerTabs: Seq[WebUITab] = parent.getTabs
+}
+
+
+/**
+ * A page that represents the leaf node in the UI hierarchy.
+ *
+ * The direct parent of a WebUIPage is not specified as it can be either a WebUI or a WebUITab.
+ * If the parent is a WebUI, the prefix is appended to the parent's address to form a full path.
+ * Else, if the parent is a WebUITab, the prefix is appended to the super prefix of the parent
+ * to form a relative path. The prefix must not contain slashes.
+ */
+private[spark] abstract class WebUIPage(var prefix: String) {
+ def render(request: HttpServletRequest): Seq[Node]
+ def renderJson(request: HttpServletRequest): JValue = JNothing
+}
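The split above makes a new UI surface two small subclasses plus an attach call. A hedged sketch of a hypothetical tab (the Demo names are illustrative, not part of this patch):

```scala
// Hypothetical example of the WebUITab/WebUIPage extension pattern.
class DemoPage(parent: DemoTab) extends WebUIPage("") {
  def render(request: HttpServletRequest): Seq[Node] = <h4>Hello from a custom page</h4>
}

class DemoTab(parent: WebUI) extends WebUITab(parent, "demo") {
  attachPage(new DemoPage(this))
}

// someUI.attachTab(new DemoTab(someUI)) registers /demo and /demo/json handlers.
```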
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
new file mode 100644
index 0000000000000..b347eb1b83c1f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.env
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") {
+ private val appName = parent.appName
+ private val basePath = parent.basePath
+ private val listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val runtimeInformationTable = UIUtils.listingTable(
+ propertyHeader, jvmRow, listener.jvmInformation, fixedWidth = true)
+ val sparkPropertiesTable = UIUtils.listingTable(
+ propertyHeader, propertyRow, listener.sparkProperties, fixedWidth = true)
+ val systemPropertiesTable = UIUtils.listingTable(
+ propertyHeader, propertyRow, listener.systemProperties, fixedWidth = true)
+ val classpathEntriesTable = UIUtils.listingTable(
+ classPathHeaders, classPathRow, listener.classpathEntries, fixedWidth = true)
+ val content =
+   <span>
+     <h4>Runtime Information</h4> {runtimeInformationTable}
+     <h4>Spark Properties</h4> {sparkPropertiesTable}
+     <h4>System Properties</h4> {systemPropertiesTable}
+     <h4>Classpath Entries</h4> {classpathEntriesTable}
+   </span>
+
+ UIUtils.headerSparkPage(content, basePath, appName, "Environment", parent.headerTabs, parent)
+ }
+
+ private def propertyHeader = Seq("Name", "Value")
+ private def classPathHeaders = Seq("Resource", "Source")
+ private def jvmRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
+ private def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
+ private def classPathRow(data: (String, String)) = <tr><td>{data._1}</td><td>{data._2}</td></tr>
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
new file mode 100644
index 0000000000000..03b46e1bd59af
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.env
+
+import org.apache.spark.scheduler._
+import org.apache.spark.ui._
+
+private[ui] class EnvironmentTab(parent: SparkUI) extends WebUITab(parent, "environment") {
+ val appName = parent.appName
+ val basePath = parent.basePath
+ val listener = new EnvironmentListener
+
+ attachPage(new EnvironmentPage(this))
+ parent.registerListener(listener)
+}
+
+/**
+ * A SparkListener that prepares information to be displayed on the EnvironmentTab
+ */
+private[ui] class EnvironmentListener extends SparkListener {
+ var jvmInformation = Seq[(String, String)]()
+ var sparkProperties = Seq[(String, String)]()
+ var systemProperties = Seq[(String, String)]()
+ var classpathEntries = Seq[(String, String)]()
+
+ override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) {
+ synchronized {
+ val environmentDetails = environmentUpdate.environmentDetails
+ jvmInformation = environmentDetails("JVM Information")
+ sparkProperties = environmentDetails("Spark Properties")
+ systemProperties = environmentDetails("System Properties")
+ classpathEntries = environmentDetails("Classpath Entries")
+ }
+ }
+}
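Because the listener is driven purely by SparkListenerEnvironmentUpdate events, the tab renders the same way for a live application and for a replayed event log. A sketch of the event shape it consumes (the values are illustrative; the four keys are the ones read above):

```scala
// Sketch: the map an environment update carries, keyed by section name.
import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate
val details: Map[String, Seq[(String, String)]] = Map(
  "JVM Information"   -> Seq("Java Version" -> "1.7.0 (Oracle Corporation)"),
  "Spark Properties"  -> Seq("spark.app.name" -> "demo"),
  "System Properties" -> Seq("os.name" -> "Linux"),
  "Classpath Entries" -> Seq("/opt/spark/lib" -> "System Classpath"))
val event = SparkListenerEnvironmentUpdate(details)
```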
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
deleted file mode 100644
index 9e7cdc88162e8..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.env
-
-import javax.servlet.http.HttpServletRequest
-
-import scala.collection.JavaConversions._
-import scala.util.Properties
-import scala.xml.Node
-
-import org.eclipse.jetty.server.Handler
-
-import org.apache.spark.SparkContext
-import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.Page.Environment
-import org.apache.spark.ui.UIUtils
-
-private[spark] class EnvironmentUI(sc: SparkContext) {
-
- def getHandlers = Seq[(String, Handler)](
- ("/environment", (request: HttpServletRequest) => envDetails(request))
- )
-
- def envDetails(request: HttpServletRequest): Seq[Node] = {
- val jvmInformation = Seq(
- ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)),
- ("Java Home", Properties.javaHome),
- ("Scala Version", Properties.versionString),
- ("Scala Home", Properties.scalaHome)
- ).sorted
- def jvmRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
- def jvmTable =
- UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true)
-
- val sparkProperties = sc.conf.getAll.sorted
-
- val systemProperties = System.getProperties.iterator.toSeq
- val classPathProperty = systemProperties.find { case (k, v) =>
- k == "java.class.path"
- }.getOrElse(("", ""))
- val otherProperties = systemProperties.filter { case (k, v) =>
- k != "java.class.path" && !k.startsWith("spark.")
- }.sorted
-
- val propertyHeaders = Seq("Name", "Value")
- def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
- val sparkPropertyTable =
- UIUtils.listingTable(propertyHeaders, propertyRow, sparkProperties, fixedWidth = true)
- val otherPropertyTable =
- UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true)
-
- val classPathEntries = classPathProperty._2
- .split(sc.conf.get("path.separator", ":"))
- .filterNot(e => e.isEmpty)
- .map(e => (e, "System Classpath"))
- val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")}
- val addedFiles = sc.addedFiles.iterator.toSeq.map{case (path, time) => (path, "Added By User")}
- val classPath = (addedJars ++ addedFiles ++ classPathEntries).sorted
-
- val classPathHeaders = Seq("Resource", "Source")
- def classPathRow(data: (String, String)) = <tr><td>{data._1}</td><td>{data._2}</td></tr>
- val classPathTable =
- UIUtils.listingTable(classPathHeaders, classPathRow, classPath, fixedWidth = true)
-
- val content =
-   <span>
-     <h4>Runtime Information</h4> {jvmTable}
-     <h4>Spark Properties</h4>
-     {sparkPropertyTable}
-     <h4>System Properties</h4>
-     {otherPropertyTable}
-     <h4>Classpath Entries</h4>
-     {classPathTable}
-   </span>
-
- UIUtils.headerSparkPage(content, sc, "Environment", Environment)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
new file mode 100644
index 0000000000000..6cb43c02b8f08
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.util.Utils
+
+private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
+ private val appName = parent.appName
+ private val basePath = parent.basePath
+ private val listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val storageStatusList = listener.storageStatusList
+ val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_ + _)
+ val memUsed = storageStatusList.map(_.memUsed).fold(0L)(_ + _)
+ val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_ + _)
+ val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId)
+ val execInfoSorted = execInfo.sortBy(_.getOrElse("Executor ID", ""))
+ val execTable = UIUtils.listingTable(execHeader, execRow, execInfoSorted)
+
+ val content =
+   <div class="row-fluid">
+     <div class="span12">
+       <ul class="unstyled">
+         <li><strong>Memory:</strong>
+           {Utils.bytesToString(memUsed)} Used
+           ({Utils.bytesToString(maxMem)} Total) </li>
+         <li><strong>Disk:</strong> {Utils.bytesToString(diskSpaceUsed)} Used </li>
+       </ul>
+     </div>
+     <div class="span12">
+       {execTable}
+     </div>
+   </div>
+
+ UIUtils.headerSparkPage(content, basePath, appName,
+   "Executors (" + execInfo.size + ")", parent.headerTabs, parent)
+ }
+
+ /** Represent an executor's info as a map given a storage status index */
+ private def getExecInfo(statusId: Int): Map[String, String] = {
+ val status = listener.storageStatusList(statusId)
+ val execId = status.blockManagerId.executorId
+ val hostPort = status.blockManagerId.hostPort
+ val rddBlocks = status.blocks.size
+ val memUsed = status.memUsed
+ val maxMem = status.maxMem
+ val diskUsed = status.diskUsed
+ val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0)
+ val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
+ val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
+ val totalTasks = activeTasks + failedTasks + completedTasks
+ val totalDuration = listener.executorToDuration.getOrElse(execId, 0)
+ val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0)
+ val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0)
+
+ // Also include fields not in the header
+ val execFields = execHeader ++ Seq("Maximum Memory")
+
+ val execValues = Seq(
+ execId,
+ hostPort,
+ rddBlocks,
+ memUsed,
+ diskUsed,
+ activeTasks,
+ failedTasks,
+ completedTasks,
+ totalTasks,
+ totalDuration,
+ totalShuffleRead,
+ totalShuffleWrite,
+ maxMem
+ ).map(_.toString)
+
+ execFields.zip(execValues).toMap
+ }
+}
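Keying the row by header name (rather than by position) lets execRow look fields up by name while "Maximum Memory" rides along outside the visible columns. A tiny sketch of the alignment:

```scala
// Sketch: zipping headers with values produces a name-keyed row.
val fields = Seq("Executor ID", "Address") ++ Seq("Maximum Memory")
val values = Seq("0", "worker1:7078", "512000000")
fields.zip(values).toMap
// -> Map("Executor ID" -> "0", "Address" -> "worker1:7078", "Maximum Memory" -> "512000000")
```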
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
new file mode 100644
index 0000000000000..5678bf34ac730
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.ExceptionFailure
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.StorageStatusListener
+import org.apache.spark.ui.{SparkUI, WebUITab}
+
+private[ui] class ExecutorsTab(parent: SparkUI) extends WebUITab(parent, "executors") {
+ val appName = parent.appName
+ val basePath = parent.basePath
+ val listener = new ExecutorsListener(parent.storageStatusListener)
+
+ attachPage(new ExecutorsPage(this))
+ parent.registerListener(listener)
+}
+
+/**
+ * A SparkListener that prepares information to be displayed on the ExecutorsTab
+ */
+private[ui] class ExecutorsListener(storageStatusListener: StorageStatusListener)
+ extends SparkListener {
+
+ val executorToTasksActive = HashMap[String, Int]()
+ val executorToTasksComplete = HashMap[String, Int]()
+ val executorToTasksFailed = HashMap[String, Int]()
+ val executorToDuration = HashMap[String, Long]()
+ val executorToShuffleRead = HashMap[String, Long]()
+ val executorToShuffleWrite = HashMap[String, Long]()
+
+ def storageStatusList = storageStatusListener.storageStatusList
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
+ val eid = formatExecutorId(taskStart.taskInfo.executorId)
+ executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
+ val info = taskEnd.taskInfo
+ if (info != null) {
+ val eid = formatExecutorId(info.executorId)
+ executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1
+ executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration
+ taskEnd.reason match {
+ case e: ExceptionFailure =>
+ executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1
+ case _ =>
+ executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
+ }
+
+ // Update shuffle read/write
+ val metrics = taskEnd.taskMetrics
+ if (metrics != null) {
+ metrics.shuffleReadMetrics.foreach { shuffleRead =>
+ executorToShuffleRead(eid) =
+ executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead
+ }
+ metrics.shuffleWriteMetrics.foreach { shuffleWrite =>
+ executorToShuffleWrite(eid) =
+ executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.shuffleBytesWritten
+ }
+ }
+ }
+ }
+
+ // This addresses executor ID inconsistencies in local mode
+ private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId)
+}
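Each map above uses the same getOrElse read-modify-write idiom, so per-executor counters need no pre-initialization; onTaskEnd's `getOrElse(eid, 1) - 1` additionally treats a missing entry as one active task, so an unmatched end event nets to zero rather than going negative. A standalone sketch of the idiom:

```scala
// Sketch of the counter idiom used by ExecutorsListener.
import scala.collection.mutable.HashMap

object TaskCounters {
  private val tasksActive = HashMap[String, Int]()

  def taskStarted(execId: String): Unit = synchronized {
    tasksActive(execId) = tasksActive.getOrElse(execId, 0) + 1
  }

  def taskFinished(execId: String): Unit = synchronized {
    // a missing key counts as one active task, matching onTaskEnd above
    tasksActive(execId) = tasksActive.getOrElse(execId, 1) - 1
  }
}
```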
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
deleted file mode 100644
index 1f3b7a4c231b6..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ /dev/null
@@ -1,181 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.exec
-
-import javax.servlet.http.HttpServletRequest
-
-import scala.collection.mutable.{HashMap, HashSet}
-import scala.xml.Node
-
-import org.eclipse.jetty.server.Handler
-
-import org.apache.spark.{ExceptionFailure, Logging, SparkContext}
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.scheduler.TaskInfo
-import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.Page.Executors
-import org.apache.spark.ui.UIUtils
-import org.apache.spark.util.Utils
-
-private[spark] class ExecutorsUI(val sc: SparkContext) {
-
- private var _listener: Option[ExecutorsListener] = None
- def listener = _listener.get
-
- def start() {
- _listener = Some(new ExecutorsListener)
- sc.addSparkListener(listener)
- }
-
- def getHandlers = Seq[(String, Handler)](
- ("/executors", (request: HttpServletRequest) => render(request))
- )
-
- def render(request: HttpServletRequest): Seq[Node] = {
- val storageStatusList = sc.getExecutorStorageStatus
-
- val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_ + _)
- val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_ + _)
- val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_ + _)
-
- val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used",
- "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read",
- "Shuffle Write")
-
- def execRow(kv: Seq[String]) = {
-
}
}
- case _ => { Seq[Node]() }
+ case _ => Seq[Node]()
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
deleted file mode 100644
index 81713edcf5db2..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.jobs
-
-import javax.servlet.http.HttpServletRequest
-
-import scala.xml.{NodeSeq, Node}
-
-import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.Page._
-import org.apache.spark.ui.UIUtils._
-
-/** Page showing list of all ongoing and recently finished stages and pools*/
-private[spark] class IndexPage(parent: JobProgressUI) {
- def listener = parent.listener
-
- def render(request: HttpServletRequest): Seq[Node] = {
- listener.synchronized {
- val activeStages = listener.activeStages.toSeq
- val completedStages = listener.completedStages.reverse.toSeq
- val failedStages = listener.failedStages.reverse.toSeq
- val now = System.currentTimeMillis()
-
- var activeTime = 0L
- for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
- activeTime += t.timeRunning(now)
- }
-
- val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent)
- val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse,
- parent)
- val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent)
-
- val pools = listener.sc.getAllPools
- val poolTable = new PoolTable(pools, listener)
- val summary: NodeSeq =
-   <div>
-     <ul class="unstyled">
-       <li>
-         <strong>Total Duration: </strong>
-         {parent.formatDuration(now - listener.sc.startTime)}
-       </li>
++
+ failedStagesTable.toNodeSeq
+
+ UIUtils.headerSparkPage(content, basePath, appName, "Spark Stages", parent.headerTabs, parent)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
new file mode 100644
index 0000000000000..3308c8c8a3d37
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import javax.servlet.http.HttpServletRequest
+
+import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.ui.{SparkUI, WebUITab}
+
+/** Web UI showing progress status of all jobs in the given SparkContext. */
+private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stages") {
+ val appName = parent.appName
+ val basePath = parent.basePath
+ val live = parent.live
+ val sc = parent.sc
+ val conf = if (live) sc.conf else new SparkConf
+ val killEnabled = conf.getBoolean("spark.ui.killEnabled", true)
+ val listener = new JobProgressListener(conf)
+
+ attachPage(new JobProgressPage(this))
+ attachPage(new StagePage(this))
+ attachPage(new PoolPage(this))
+ parent.registerListener(listener)
+
+ def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
+
+ def handleKillRequest(request: HttpServletRequest) = {
+ if (killEnabled) {
+ val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
+ val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
+ if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) {
+ sc.cancelStage(stageId)
+ }
+ // Do a quick pause here to give Spark time to kill the stage so it shows up as
+ // killed after the refresh. Note that this will block the serving thread so the
+ // time should be limited in duration.
+ Thread.sleep(100)
+ }
+ }
+}
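Together with the /stages/stage/kill redirect registered in SparkUI.initialize() earlier in this patch, this makes killing a stage a plain HTTP GET. A hedged sketch (host, port, and stage id are illustrative):

```scala
// Sketch: trigger the kill handler; the UI then redirects back to /stages.
// Missing parameters default safely (terminate=false, id=-1), so malformed requests are no-ops.
val url = "http://localhost:4040/stages/stage/kill/?id=3&terminate=true"
scala.io.Source.fromURL(url).close()
```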
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
deleted file mode 100644
index 557bce6b66353..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.jobs
-
-import java.text.SimpleDateFormat
-import javax.servlet.http.HttpServletRequest
-
-import scala.Seq
-
-import org.eclipse.jetty.server.Handler
-
-import org.apache.spark.SparkContext
-import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.Utils
-
-/** Web UI showing progress status of all jobs in the given SparkContext. */
-private[spark] class JobProgressUI(val sc: SparkContext) {
- private var _listener: Option[JobProgressListener] = None
- def listener = _listener.get
- val dateFmt = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
-
- private val indexPage = new IndexPage(this)
- private val stagePage = new StagePage(this)
- private val poolPage = new PoolPage(this)
-
- def start() {
- _listener = Some(new JobProgressListener(sc))
- sc.addSparkListener(listener)
- }
-
- def formatDuration(ms: Long) = Utils.msDurationToString(ms)
-
- def getHandlers = Seq[(String, Handler)](
- ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)),
- ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)),
- ("/stages", (request: HttpServletRequest) => indexPage.render(request))
- )
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index eb7518a020840..0a2bf31833d2b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -21,27 +21,38 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.ui.Page._
-import org.apache.spark.ui.UIUtils._
+import org.apache.spark.scheduler.{Schedulable, StageInfo}
+import org.apache.spark.ui.{WebUIPage, UIUtils}
/** Page showing specific pool details */
-private[spark] class PoolPage(parent: JobProgressUI) {
- def listener = parent.listener
+private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
+ private val appName = parent.appName
+ private val basePath = parent.basePath
+ private val live = parent.live
+ private val sc = parent.sc
+ private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val poolName = request.getParameter("poolname")
val poolToActiveStages = listener.poolToActiveStages
- val activeStages = poolToActiveStages.get(poolName).toSeq.flatten
- val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent)
-
- val pool = listener.sc.getPoolForName(poolName).get
- val poolTable = new PoolTable(Seq(pool), listener)
-
- val content =
-   <h4>Summary </h4> ++ poolTable.toNodeSeq() ++
-   <h4>{activeStages.size} Active Stages</h4> ++ activeStagesTable.toNodeSeq()
-
- headerSparkPage(content, parent.sc, "Fair Scheduler Pool: " + poolName, Stages)
+ val activeStages = poolToActiveStages.get(poolName) match {
+ case Some(s) => s.values.toSeq
+ case None => Seq[StageInfo]()
+ }
+ val activeStagesTable =
+ new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent)
+
+ // For now, pool information is only accessible in live UIs
+ val pools = if (live) Seq(sc.getPoolForName(poolName).get) else Seq[Schedulable]()
+ val poolTable = new PoolTable(pools, parent)
+
+ val content =
+   <h4>Summary </h4> ++ poolTable.toNodeSeq ++
+   <h4>{activeStages.size} Active Stages</h4> ++ activeStagesTable.toNodeSeq
+
+ UIUtils.headerSparkPage(content, basePath, appName, "Fair Scheduler Pool: " + poolName,
+   parent.headerTabs, parent)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index ddc687a45a095..4bce472036f7d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -22,54 +22,52 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.ExceptionFailure
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.scheduler.TaskInfo
-import org.apache.spark.ui.UIUtils._
-import org.apache.spark.ui.Page._
+import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.{Utils, Distribution}
/** Page showing statistics and task list for a given stage */
-private[spark] class StagePage(parent: JobProgressUI) {
- def listener = parent.listener
- val dateFmt = parent.dateFmt
+private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
+ private val appName = parent.appName
+ private val basePath = parent.basePath
+ private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val stageId = request.getParameter("id").toInt
- val now = System.currentTimeMillis()
- if (!listener.stageIdToTaskInfos.contains(stageId)) {
+ if (!listener.stageIdToTaskData.contains(stageId)) {
val content =
  <div>
    <h4>Summary Metrics</h4> No tasks have started yet
    <h4>Tasks</h4> No tasks have started yet
  </div>
- return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
+ return UIUtils.headerSparkPage(content, basePath, appName,
+ "Details for Stage %s".format(stageId), parent.headerTabs, parent)
}
- val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+ val tasks = listener.stageIdToTaskData(stageId).values.toSeq.sortBy(_.taskInfo.launchTime)
- val numCompleted = tasks.count(_._1.finished)
+ val numCompleted = tasks.count(_.taskInfo.finished)
val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
val memoryBytesSpilled = listener.stageIdToMemoryBytesSpilled.getOrElse(stageId, 0L)
val diskBytesSpilled = listener.stageIdToDiskBytesSpilled.getOrElse(stageId, 0L)
- val hasBytesSpilled = (memoryBytesSpilled > 0 && diskBytesSpilled > 0)
+ val hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0
var activeTime = 0L
- listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ val now = System.currentTimeMillis
+ val tasksActive = listener.stageIdToTasksActive(stageId).values
+ tasksActive.foreach(activeTime += _.timeRunning(now))
- val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished)
// scalastyle:off
val summary =
  <div>
    <ul class="unstyled">
      <li>
        <strong>Total task time across all tasks: </strong>
-       {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
+       {UIUtils.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
      </li>
{if (hasShuffleRead)
@@ -104,42 +102,45 @@ private[spark] class StagePage(parent: JobProgressUI) {
{if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++
Seq("Errors")
- val taskTable = listingTable(
+ val taskTable = UIUtils.listingTable(
taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite, hasBytesSpilled), tasks)
// Excludes tasks which failed and have incomplete metrics
- val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined))
+ val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined)
val summaryTable: Option[Seq[Node]] =
if (validTasks.size == 0) {
None
}
else {
- val serializationTimes = validTasks.map{case (info, metrics, exception) =>
- metrics.get.resultSerializationTime.toDouble}
+ val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.resultSerializationTime.toDouble
+ }
val serializationQuantiles =
"Result serialization time" +: Distribution(serializationTimes).
- get.getQuantiles().map(ms => parent.formatDuration(ms.toLong))
+ get.getQuantiles().map(ms => UIUtils.formatDuration(ms.toLong))
- val serviceTimes = validTasks.map{case (info, metrics, exception) =>
- metrics.get.executorRunTime.toDouble}
- val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map(
- ms => parent.formatDuration(ms.toLong))
+ val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.executorRunTime.toDouble
+ }
+ val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles()
+ .map(ms => UIUtils.formatDuration(ms.toLong))
- val gettingResultTimes = validTasks.map{case (info, metrics, exception) =>
+ val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
if (info.gettingResultTime > 0) {
(info.finishTime - info.gettingResultTime).toDouble
} else {
0.0
}
}
- val gettingResultQuantiles = ("Time spent fetching task results" +:
- Distribution(gettingResultTimes).get.getQuantiles().map(
- millis => parent.formatDuration(millis.toLong)))
+ val gettingResultQuantiles = "Time spent fetching task results" +:
+ Distribution(gettingResultTimes).get.getQuantiles().map { millis =>
+ UIUtils.formatDuration(millis.toLong)
+ }
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
- val schedulerDelays = validTasks.map{case (info, metrics, exception) =>
+ val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
val totalExecutionTime = {
if (info.gettingResultTime > 0) {
(info.gettingResultTime - info.launchTime).toDouble
@@ -149,35 +150,32 @@ private[spark] class StagePage(parent: JobProgressUI) {
}
totalExecutionTime - metrics.get.executorRunTime
}
- val schedulerDelayQuantiles = ("Scheduler delay" +:
- Distribution(schedulerDelays).get.getQuantiles().map(
- millis => parent.formatDuration(millis.toLong)))
+ val schedulerDelayQuantiles = "Scheduler delay" +:
+ Distribution(schedulerDelays).get.getQuantiles().map { millis =>
+ UIUtils.formatDuration(millis.toLong)
+ }
def getQuantileCols(data: Seq[Double]) =
Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong))
- val shuffleReadSizes = validTasks.map {
- case(info, metrics, exception) =>
- metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
+ val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
}
val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes)
- val shuffleWriteSizes = validTasks.map {
- case(info, metrics, exception) =>
- metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble
+ val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble
}
val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes)
- val memoryBytesSpilledSizes = validTasks.map {
- case(info, metrics, exception) =>
- metrics.get.memoryBytesSpilled.toDouble
+ val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.memoryBytesSpilled.toDouble
}
val memoryBytesSpilledQuantiles = "Shuffle spill (memory)" +:
getQuantileCols(memoryBytesSpilledSizes)
- val diskBytesSpilledSizes = validTasks.map {
- case(info, metrics, exception) =>
- metrics.get.diskBytesSpilled.toDouble
+ val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.diskBytesSpilled.toDouble
}
val diskBytesSpilledQuantiles = "Shuffle spill (disk)" +:
getQuantileCols(diskBytesSpilledSizes)
@@ -195,98 +193,104 @@ private[spark] class StagePage(parent: JobProgressUI) {
val quantileHeaders = Seq("Metric", "Min", "25th percentile",
"Median", "75th percentile", "Max")
def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr>
- Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
+ Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
}
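Each metric column funnels through Distribution.getQuantiles(), whose default probabilities yield exactly the five quantileHeaders columns: min, 25th percentile, median, 75th percentile, and max. A minimal sketch:

```scala
// Sketch: the five default quantiles backing the summary table.
import org.apache.spark.util.Distribution
val quantiles = Distribution(Seq(1.0, 2.0, 3.0, 4.0, 5.0)).get.getQuantiles()
// -> Vector(1.0, 2.0, 3.0, 4.0, 5.0): min, 25th, median, 75th, max
```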
- val executorTable = new ExecutorTable(parent, stageId)
+ val executorTable = new ExecutorTable(stageId, parent)
val content =
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
{summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
- <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq() ++
+ <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++
<h4>Tasks</h4> ++ taskTable
- headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
+ UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId),
+ parent.headerTabs, parent)
}
}
def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean, bytesSpilled: Boolean)
- (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = {
+ (taskData: TaskUIData): Seq[Node] = {
def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] =
trace.map(e => {e.toString})
- val (info, metrics, exception) = taskData
-
- val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis())
- else metrics.map(m => m.executorRunTime).getOrElse(1)
- val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration)
- else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
- val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
- val serializationTime = metrics.map(m => m.resultSerializationTime).getOrElse(0L)
-
- val maybeShuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s => s.remoteBytesRead)
- val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
- val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("")
-
- val maybeShuffleWrite =
- metrics.flatMap{m => m.shuffleWriteMetrics}.map(s => s.shuffleBytesWritten)
- val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("")
- val shuffleWriteReadable = maybeShuffleWrite.map(Utils.bytesToString).getOrElse("")
-
- val maybeWriteTime = metrics.flatMap(m => m.shuffleWriteMetrics).map(s => s.shuffleWriteTime)
- val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("")
- val writeTimeReadable = maybeWriteTime.map( t => t / (1000 * 1000)).map{ ms =>
- if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("")
-
- val maybeMemoryBytesSpilled = metrics.map(m => m.memoryBytesSpilled)
- val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.map(_.toString).getOrElse("")
- val memoryBytesSpilledReadable = maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("")
-
- val maybeDiskBytesSpilled = metrics.map{m => m.diskBytesSpilled}
- val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("")
- val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("")
-
-
)
+ }
- private def stageRow(s: StageInfo): Seq[Node] = {
+ protected def stageRow(s: StageInfo): Seq[Node] = {
+ val poolName = listener.stageIdToPool.get(s.stageId)
val submissionTime = s.submissionTime match {
- case Some(t) => dateFmt.format(new Date(t))
+ case Some(t) => UIUtils.formatDate(new Date(t))
case None => "Unknown"
}
-
+ val finishTime = s.completionTime.getOrElse(System.currentTimeMillis)
+ val duration = s.submissionTime.map { t =>
+ if (finishTime > t) finishTime - t else System.currentTimeMillis - t
+ }
+ val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
+ val startedTasks =
+ listener.stageIdToTasksActive.getOrElse(s.stageId, HashMap[Long, TaskInfo]()).size
+ val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
+ val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
+ case f if f > 0 => "(%s failed)".format(f)
+ case _ => ""
+ }
+ val totalTasks = s.numTasks
val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L)
val shuffleRead = shuffleReadSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
-
val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L)
val shuffleWrite = shuffleWriteSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
+
+ }
- val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
- val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
- val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
- case f if f > 0 => "(%s failed)".format(f)
- case _ => ""
- }
- val totalTasks = s.numTasks
+ /** Render an HTML row that represents a stage */
+ private def renderStageRow(s: StageInfo): Seq[Node] = <tr>{stageRow(s)}</tr>
+}
- val poolName = listener.stageIdToPool.get(s.stageId)
+private[ui] class FailedStageTable(
+ stages: Seq[StageInfo],
+ parent: JobProgressTab,
+ killEnabled: Boolean = false)
+ extends StageTableBase(stages, parent, killEnabled) {
- val nameLink =
-   <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(), s.stageId)}>{s.name}</a>
- val description = listener.stageIdToDescription.get(s.stageId)
-   .map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
- val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
- val duration = s.submissionTime.map(t => finishTime - t)
-
-
+
+ override protected def stageRow(s: StageInfo): Seq[Node] = {
+ val basicColumns = super.stageRow(s)
+ val failureReason = <td valign="middle">{s.failureReason.getOrElse("")}</td>
+ basicColumns ++ failureReason
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
deleted file mode 100644
index dc18eab74e0da..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.storage
-
-import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.Handler
-
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.ui.JettyUtils._
-
-/** Web UI showing storage status of all RDD's in the given SparkContext. */
-private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging {
- val indexPage = new IndexPage(this)
- val rddPage = new RDDPage(this)
-
- def getHandlers = Seq[(String, Handler)](
- ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)),
- ("/storage", (request: HttpServletRequest) => indexPage.render(request))
- )
-}
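
Aside: the deleted `BlockManagerUI` wired each page to a URL by hand as `(String, Handler)` pairs; the `StorageTab`/`StoragePage` replacement below instead lets a page declare only its prefix and render method while the tab derives the routes. A toy sketch of that routing idea, assuming a page is just a prefix plus a render function — none of these names are Spark's real `WebUITab`/`WebUIPage` API:

```scala
import scala.xml.Node

// A page owns a relative URL prefix and knows how to render itself.
abstract class Page(val prefix: String) {
  def render(params: Map[String, String]): Seq[Node]
}

// A tab collects pages and derives (path -> renderer) bindings from them,
// replacing the hand-written handler list of the deleted class.
class Tab(val basePath: String) {
  private var pages = Vector.empty[Page]
  def attachPage(page: Page): Unit = { pages :+= page }

  def handlers: Seq[(String, Map[String, String] => Seq[Node])] =
    pages.map(p => (s"$basePath/${p.prefix}".stripSuffix("/"), p.render _))
}

object RoutingDemo extends App {
  val storage = new Tab("/storage")
  storage.attachPage(new Page("rdd") {
    def render(params: Map[String, String]) = <p>RDD {params.getOrElse("id", "?")}</p>
  })
  storage.attachPage(new Page("") {
    def render(params: Map[String, String]) = <p>all RDDs</p>
  })
  storage.handlers.foreach { case (path, h) => println(s"$path -> ${h(Map("id" -> "3"))}") }
}
```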
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala
deleted file mode 100644
index 6a3c41fb1155d..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.storage
-
-import javax.servlet.http.HttpServletRequest
-
-import scala.xml.Node
-
-import org.apache.spark.storage.{RDDInfo, StorageUtils}
-import org.apache.spark.ui.Page._
-import org.apache.spark.ui.UIUtils._
-import org.apache.spark.util.Utils
-
-/** Page showing list of RDD's currently stored in the cluster */
-private[spark] class IndexPage(parent: BlockManagerUI) {
- val sc = parent.sc
-
- def render(request: HttpServletRequest): Seq[Node] = {
- val storageStatusList = sc.getExecutorStorageStatus
- // Calculate macro-level statistics
-
- val rddHeaders = Seq(
- "RDD Name",
- "Storage Level",
- "Cached Partitions",
- "Fraction Cached",
- "Size in Memory",
- "Size on Disk")
- val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
- val content = listingTable(rddHeaders, rddRow, rdds)
-
- headerSparkPage(content, parent.sc, "Storage ", Storage)
- }
-
-  def rddRow(rdd: RDDInfo): Seq[Node] = {
-    <tr>
-      <td>
-        <a href={"%s/storage/rdd?id=%s".format(prependBaseUri(),rdd.id)}>
-          {rdd.name}
-        </a>
-      </td>
-      <td>{rdd.storageLevel.description}
-      </td>
-      <td>{rdd.numCachedPartitions}</td>
-      <td>{"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)}</td>
-      <td>{Utils.bytesToString(rdd.memSize)}</td>
-      <td>{Utils.bytesToString(rdd.diskSize)}</td>
-    </tr>
-  }
-}
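
Aside: both the deleted `IndexPage` and its `StoragePage` replacement below build the RDD table through the same `listingTable(headers, rowRenderer, rows)` helper: pass column headers and a per-row renderer, get back a complete table. A sketch of that helper's contract — an illustration, not `UIUtils.listingTable` itself:

```scala
import scala.xml.Node

object ListingTableDemo extends App {
  // Given headers and a function that renders one row's cells, emit a table.
  def listingTable[T](headers: Seq[String], makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] =
    <table class="table">
      <thead><tr>{headers.map(h => <th>{h}</th>)}</tr></thead>
      <tbody>{rows.map(r => <tr>{makeRow(r)}</tr>)}</tbody>
    </table>

  // Hypothetical row type and renderer, mirroring the RDD rows in this diff.
  case class Rdd(name: String, cachedPartitions: Int)
  def rddRow(r: Rdd): Seq[Node] = <td>{r.name}</td> ++ <td>{r.cachedPartitions}</td>

  println(listingTable(Seq("RDD Name", "Cached Partitions"), rddRow,
    Seq(Rdd("lines", 4), Rdd("pairs", 8))))
}
```

(Whether the helper or the row renderer supplies the enclosing `<tr>` is a detail that varies between versions; in this sketch the helper adds it.)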
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
new file mode 100644
index 0000000000000..b66edd91f56c0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.storage
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.storage.RDDInfo
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.util.Utils
+
+/** Page showing list of RDD's currently stored in the cluster */
+private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
+ private val appName = parent.appName
+ private val basePath = parent.basePath
+ private val listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val rdds = listener.rddInfoList
+ val content = UIUtils.listingTable(rddHeader, rddRow, rdds)
+ UIUtils.headerSparkPage(content, basePath, appName, "Storage ", parent.headerTabs, parent)
+ }
+
+ /** Header fields for the RDD table */
+ private def rddHeader = Seq(
+ "RDD Name",
+ "Storage Level",
+ "Cached Partitions",
+ "Fraction Cached",
+ "Size in Memory",
+ "Size in Tachyon",
+ "Size on Disk")
+
+ /** Render an HTML row representing an RDD */
+ private def rddRow(rdd: RDDInfo): Seq[Node] = {
+